1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
7pub struct ArithmeticEvaluator<'a> {
10 table: &'a DataTable,
11}
12
13impl<'a> ArithmeticEvaluator<'a> {
14 pub fn new(table: &'a DataTable) -> Self {
15 Self { table }
16 }
17
18 fn find_similar_column(&self, name: &str) -> Option<String> {
20 let columns = self.table.column_names();
21 let mut best_match: Option<(String, usize)> = None;
22
23 for col in columns {
24 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25 let max_distance = if name.len() > 10 { 3 } else { 2 };
28 if distance <= max_distance {
29 match &best_match {
30 None => best_match = Some((col, distance)),
31 Some((_, best_dist)) if distance < *best_dist => {
32 best_match = Some((col, distance));
33 }
34 _ => {}
35 }
36 }
37 }
38
39 best_match.map(|(name, _)| name)
40 }
41
42 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44 let len1 = s1.len();
45 let len2 = s2.len();
46 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48 for i in 0..=len1 {
49 matrix[i][0] = i;
50 }
51 for j in 0..=len2 {
52 matrix[0][j] = j;
53 }
54
55 for (i, c1) in s1.chars().enumerate() {
56 for (j, c2) in s2.chars().enumerate() {
57 let cost = if c1 == c2 { 0 } else { 1 };
58 matrix[i + 1][j + 1] = std::cmp::min(
59 matrix[i][j + 1] + 1, std::cmp::min(
61 matrix[i + 1][j] + 1, matrix[i][j] + cost, ),
64 );
65 }
66 }
67
68 matrix[len1][len2]
69 }
70
71 pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73 debug!(
74 "ArithmeticEvaluator: evaluating {:?} for row {}",
75 expr, row_index
76 );
77
78 match expr {
79 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82 SqlExpression::BinaryOp { left, op, right } => {
83 self.evaluate_binary_op(left, op, right, row_index)
84 }
85 SqlExpression::FunctionCall { name, args } => {
86 self.evaluate_function(name, args, row_index)
87 }
88 _ => Err(anyhow!(
89 "Unsupported expression type for arithmetic evaluation: {:?}",
90 expr
91 )),
92 }
93 }
94
95 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
97 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
98 let suggestion = self.find_similar_column(column_name);
99 match suggestion {
100 Some(similar) => anyhow!(
101 "Column '{}' not found. Did you mean '{}'?",
102 column_name,
103 similar
104 ),
105 None => anyhow!("Column '{}' not found", column_name),
106 }
107 })?;
108
109 if row_index >= self.table.row_count() {
110 return Err(anyhow!("Row index {} out of bounds", row_index));
111 }
112
113 let row = self
114 .table
115 .get_row(row_index)
116 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
117
118 let value = row
119 .get(col_index)
120 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
121
122 Ok(value.clone())
123 }
124
125 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
127 if let Ok(int_val) = number_str.parse::<i64>() {
129 return Ok(DataValue::Integer(int_val));
130 }
131
132 if let Ok(float_val) = number_str.parse::<f64>() {
134 return Ok(DataValue::Float(float_val));
135 }
136
137 Err(anyhow!("Invalid number literal: {}", number_str))
138 }
139
140 fn evaluate_binary_op(
142 &self,
143 left: &SqlExpression,
144 op: &str,
145 right: &SqlExpression,
146 row_index: usize,
147 ) -> Result<DataValue> {
148 let left_val = self.evaluate(left, row_index)?;
149 let right_val = self.evaluate(right, row_index)?;
150
151 debug!(
152 "ArithmeticEvaluator: {} {} {}",
153 self.format_value(&left_val),
154 op,
155 self.format_value(&right_val)
156 );
157
158 match op {
159 "+" => self.add_values(&left_val, &right_val),
160 "-" => self.subtract_values(&left_val, &right_val),
161 "*" => self.multiply_values(&left_val, &right_val),
162 "/" => self.divide_values(&left_val, &right_val),
163 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
164 }
165 }
166
167 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
169 match (left, right) {
170 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
171 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
172 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
173 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
174 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
175 }
176 }
177
178 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
180 match (left, right) {
181 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
182 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
183 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
184 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
185 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
186 }
187 }
188
189 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
191 match (left, right) {
192 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
193 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
194 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
195 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
196 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
197 }
198 }
199
200 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
202 let is_zero = match right {
204 DataValue::Integer(0) => true,
205 DataValue::Float(f) if f.abs() < f64::EPSILON => true,
206 _ => false,
207 };
208
209 if is_zero {
210 return Err(anyhow!("Division by zero"));
211 }
212
213 match (left, right) {
214 (DataValue::Integer(a), DataValue::Integer(b)) => {
215 if a % b == 0 {
217 Ok(DataValue::Integer(a / b))
218 } else {
219 Ok(DataValue::Float(*a as f64 / *b as f64))
220 }
221 }
222 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
223 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
224 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
225 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
226 }
227 }
228
229 fn format_value(&self, value: &DataValue) -> String {
231 match value {
232 DataValue::Integer(i) => i.to_string(),
233 DataValue::Float(f) => f.to_string(),
234 DataValue::String(s) => format!("'{}'", s),
235 _ => format!("{:?}", value),
236 }
237 }
238
239 fn evaluate_function(
241 &self,
242 name: &str,
243 args: &[SqlExpression],
244 row_index: usize,
245 ) -> Result<DataValue> {
246 match name {
247 "ROUND" => {
248 if args.is_empty() || args.len() > 2 {
249 return Err(anyhow!("ROUND requires 1 or 2 arguments"));
250 }
251
252 let value = self.evaluate(&args[0], row_index)?;
254
255 let decimals = if args.len() == 2 {
257 match self.evaluate(&args[1], row_index)? {
258 DataValue::Integer(n) => n as i32,
259 DataValue::Float(f) => f as i32,
260 _ => return Err(anyhow!("ROUND precision must be a number")),
261 }
262 } else {
263 0
264 };
265
266 match value {
268 DataValue::Integer(n) => Ok(DataValue::Integer(n)), DataValue::Float(f) => {
270 if decimals >= 0 {
271 let multiplier = 10_f64.powi(decimals);
272 let rounded = (f * multiplier).round() / multiplier;
273 if decimals == 0 {
274 Ok(DataValue::Integer(rounded as i64))
276 } else {
277 Ok(DataValue::Float(rounded))
278 }
279 } else {
280 let divisor = 10_f64.powi(-decimals);
282 let rounded = (f / divisor).round() * divisor;
283 Ok(DataValue::Float(rounded))
284 }
285 }
286 _ => Err(anyhow!("ROUND can only be applied to numeric values")),
287 }
288 }
289 "ABS" => {
290 if args.len() != 1 {
291 return Err(anyhow!("ABS requires exactly 1 argument"));
292 }
293
294 let value = self.evaluate(&args[0], row_index)?;
295 match value {
296 DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
297 DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
298 _ => Err(anyhow!("ABS can only be applied to numeric values")),
299 }
300 }
301 "FLOOR" => {
302 if args.len() != 1 {
303 return Err(anyhow!("FLOOR requires exactly 1 argument"));
304 }
305
306 let value = self.evaluate(&args[0], row_index)?;
307 match value {
308 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
309 DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
310 _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
311 }
312 }
313 "CEILING" | "CEIL" => {
314 if args.len() != 1 {
315 return Err(anyhow!("CEILING requires exactly 1 argument"));
316 }
317
318 let value = self.evaluate(&args[0], row_index)?;
319 match value {
320 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
321 DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
322 _ => Err(anyhow!("CEILING can only be applied to numeric values")),
323 }
324 }
325 "MOD" => {
326 if args.len() != 2 {
327 return Err(anyhow!("MOD requires exactly 2 arguments"));
328 }
329
330 let dividend = self.evaluate(&args[0], row_index)?;
331 let divisor = self.evaluate(&args[1], row_index)?;
332
333 match (÷nd, &divisor) {
334 (DataValue::Integer(n), DataValue::Integer(d)) => {
335 if *d == 0 {
336 return Err(anyhow!("Division by zero in MOD"));
337 }
338 Ok(DataValue::Integer(n % d))
339 }
340 _ => {
341 let n = match dividend {
343 DataValue::Integer(i) => i as f64,
344 DataValue::Float(f) => f,
345 _ => return Err(anyhow!("MOD requires numeric arguments")),
346 };
347 let d = match divisor {
348 DataValue::Integer(i) => i as f64,
349 DataValue::Float(f) => f,
350 _ => return Err(anyhow!("MOD requires numeric arguments")),
351 };
352 if d == 0.0 {
353 return Err(anyhow!("Division by zero in MOD"));
354 }
355 Ok(DataValue::Float(n % d))
356 }
357 }
358 }
359 "QUOTIENT" => {
360 if args.len() != 2 {
361 return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
362 }
363
364 let numerator = self.evaluate(&args[0], row_index)?;
365 let denominator = self.evaluate(&args[1], row_index)?;
366
367 match (&numerator, &denominator) {
368 (DataValue::Integer(n), DataValue::Integer(d)) => {
369 if *d == 0 {
370 return Err(anyhow!("Division by zero in QUOTIENT"));
371 }
372 Ok(DataValue::Integer(n / d))
373 }
374 _ => {
375 let n = match numerator {
377 DataValue::Integer(i) => i as f64,
378 DataValue::Float(f) => f,
379 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
380 };
381 let d = match denominator {
382 DataValue::Integer(i) => i as f64,
383 DataValue::Float(f) => f,
384 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
385 };
386 if d == 0.0 {
387 return Err(anyhow!("Division by zero in QUOTIENT"));
388 }
389 Ok(DataValue::Integer((n / d).trunc() as i64))
390 }
391 }
392 }
393 "POWER" | "POW" => {
394 if args.len() != 2 {
395 return Err(anyhow!("POWER requires exactly 2 arguments"));
396 }
397
398 let base = self.evaluate(&args[0], row_index)?;
399 let exponent = self.evaluate(&args[1], row_index)?;
400
401 match (&base, &exponent) {
402 (DataValue::Integer(b), DataValue::Integer(e)) => {
403 if *e >= 0 && *e <= i32::MAX as i64 {
404 Ok(DataValue::Float((*b as f64).powi(*e as i32)))
405 } else {
406 Ok(DataValue::Float((*b as f64).powf(*e as f64)))
407 }
408 }
409 _ => {
410 let b = match base {
412 DataValue::Integer(i) => i as f64,
413 DataValue::Float(f) => f,
414 _ => return Err(anyhow!("POWER requires numeric arguments")),
415 };
416 let e = match exponent {
417 DataValue::Integer(i) => i as f64,
418 DataValue::Float(f) => f,
419 _ => return Err(anyhow!("POWER requires numeric arguments")),
420 };
421 Ok(DataValue::Float(b.powf(e)))
422 }
423 }
424 }
425 "SQRT" => {
426 if args.len() != 1 {
427 return Err(anyhow!("SQRT requires exactly 1 argument"));
428 }
429
430 let value = self.evaluate(&args[0], row_index)?;
431 match value {
432 DataValue::Integer(n) => {
433 if n < 0 {
434 return Err(anyhow!("SQRT of negative number"));
435 }
436 Ok(DataValue::Float((n as f64).sqrt()))
437 }
438 DataValue::Float(f) => {
439 if f < 0.0 {
440 return Err(anyhow!("SQRT of negative number"));
441 }
442 Ok(DataValue::Float(f.sqrt()))
443 }
444 _ => Err(anyhow!("SQRT can only be applied to numeric values")),
445 }
446 }
447 "EXP" => {
448 if args.len() != 1 {
449 return Err(anyhow!("EXP requires exactly 1 argument"));
450 }
451
452 let value = self.evaluate(&args[0], row_index)?;
453 match value {
454 DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
455 DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
456 _ => Err(anyhow!("EXP can only be applied to numeric values")),
457 }
458 }
459 "LN" => {
460 if args.len() != 1 {
461 return Err(anyhow!("LN requires exactly 1 argument"));
462 }
463
464 let value = self.evaluate(&args[0], row_index)?;
465 match value {
466 DataValue::Integer(n) => {
467 if n <= 0 {
468 return Err(anyhow!("LN of non-positive number"));
469 }
470 Ok(DataValue::Float((n as f64).ln()))
471 }
472 DataValue::Float(f) => {
473 if f <= 0.0 {
474 return Err(anyhow!("LN of non-positive number"));
475 }
476 Ok(DataValue::Float(f.ln()))
477 }
478 _ => Err(anyhow!("LN can only be applied to numeric values")),
479 }
480 }
481 "LOG" | "LOG10" => {
482 if name == "LOG" && args.len() == 2 {
483 let value = self.evaluate(&args[0], row_index)?;
485 let base = self.evaluate(&args[1], row_index)?;
486
487 let n = match value {
488 DataValue::Integer(i) => i as f64,
489 DataValue::Float(f) => f,
490 _ => return Err(anyhow!("LOG requires numeric arguments")),
491 };
492 let b = match base {
493 DataValue::Integer(i) => i as f64,
494 DataValue::Float(f) => f,
495 _ => return Err(anyhow!("LOG requires numeric arguments")),
496 };
497
498 if n <= 0.0 {
499 return Err(anyhow!("LOG of non-positive number"));
500 }
501 if b <= 0.0 || b == 1.0 {
502 return Err(anyhow!("Invalid LOG base"));
503 }
504 Ok(DataValue::Float(n.log(b)))
505 } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
506 if args.len() != 1 {
508 return Err(anyhow!("{} requires exactly 1 argument", name));
509 }
510
511 let value = self.evaluate(&args[0], row_index)?;
512 match value {
513 DataValue::Integer(n) => {
514 if n <= 0 {
515 return Err(anyhow!("LOG10 of non-positive number"));
516 }
517 Ok(DataValue::Float((n as f64).log10()))
518 }
519 DataValue::Float(f) => {
520 if f <= 0.0 {
521 return Err(anyhow!("LOG10 of non-positive number"));
522 }
523 Ok(DataValue::Float(f.log10()))
524 }
525 _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
526 }
527 } else {
528 Err(anyhow!("LOG requires 1 or 2 arguments"))
529 }
530 }
531 "PI" => {
532 if !args.is_empty() {
533 return Err(anyhow!("PI takes no arguments"));
534 }
535 Ok(DataValue::Float(std::f64::consts::PI))
536 }
537 "DATEDIFF" => {
538 if args.len() != 3 {
539 return Err(anyhow!(
540 "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
541 ));
542 }
543
544 let unit = match self.evaluate(&args[0], row_index)? {
546 DataValue::String(s) => s.to_lowercase(),
547 DataValue::InternedString(s) => s.to_lowercase(),
548 _ => return Err(anyhow!("DATEDIFF unit must be a string")),
549 };
550
551 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
553 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
554 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
558 return Ok(Utc.from_utc_datetime(&dt));
559 }
560 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
561 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
562 }
563
564 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
566 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
567 }
568 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
569 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
570 }
571
572 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
574 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
575 }
576 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
577 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
578 }
579
580 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
582 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
583 }
584
585 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
587 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
588 }
589 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
590 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
591 }
592
593 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
595 return Ok(Utc.from_utc_datetime(&dt));
596 }
597 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
598 return Ok(Utc.from_utc_datetime(&dt));
599 }
600
601 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
603 return Ok(dt);
604 }
605
606 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
607 };
608
609 match value {
610 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
611 DataValue::InternedString(s) => parse_string(s.as_str()),
612 _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
613 }
614 };
615
616 let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
618 let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
619
620 let diff = match unit.as_str() {
622 "day" | "days" => {
623 let duration = date2.signed_duration_since(date1);
624 duration.num_days()
625 }
626 "month" | "months" => {
627 let duration = date2.signed_duration_since(date1);
629 duration.num_days() / 30
630 }
631 "year" | "years" => {
632 let duration = date2.signed_duration_since(date1);
634 duration.num_days() / 365
635 }
636 "hour" | "hours" => {
637 let duration = date2.signed_duration_since(date1);
638 duration.num_hours()
639 }
640 "minute" | "minutes" => {
641 let duration = date2.signed_duration_since(date1);
642 duration.num_minutes()
643 }
644 "second" | "seconds" => {
645 let duration = date2.signed_duration_since(date1);
646 duration.num_seconds()
647 }
648 _ => {
649 return Err(anyhow!(
650 "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
651 unit
652 ))
653 }
654 };
655
656 Ok(DataValue::Integer(diff))
657 }
658 "NOW" => {
659 if !args.is_empty() {
660 return Err(anyhow!("NOW takes no arguments"));
661 }
662 let now = Utc::now();
663 Ok(DataValue::DateTime(
664 now.format("%Y-%m-%d %H:%M:%S").to_string(),
665 ))
666 }
667 "TODAY" => {
668 if !args.is_empty() {
669 return Err(anyhow!("TODAY takes no arguments"));
670 }
671 let today = Utc::now().date_naive();
672 Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
673 }
674 "DATEADD" => {
675 if args.len() != 3 {
676 return Err(anyhow!(
677 "DATEADD requires exactly 3 arguments: unit, number, date"
678 ));
679 }
680
681 let unit = match self.evaluate(&args[0], row_index)? {
683 DataValue::String(s) => s.to_lowercase(),
684 DataValue::InternedString(s) => s.to_lowercase(),
685 _ => return Err(anyhow!("DATEADD unit must be a string")),
686 };
687
688 let amount = match self.evaluate(&args[1], row_index)? {
690 DataValue::Integer(i) => i,
691 DataValue::Float(f) => f as i64,
692 _ => return Err(anyhow!("DATEADD amount must be a number")),
693 };
694
695 let base_date_value = self.evaluate(&args[2], row_index)?;
697
698 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
700 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
701 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
705 return Ok(Utc.from_utc_datetime(&dt));
706 }
707 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
708 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
709 }
710
711 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
713 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
714 }
715 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
716 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
717 }
718
719 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
721 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
722 }
723 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
724 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
725 }
726
727 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
729 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
730 }
731
732 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
734 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
735 }
736 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
737 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
738 }
739
740 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
742 return Ok(Utc.from_utc_datetime(&dt));
743 }
744 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
745 return Ok(Utc.from_utc_datetime(&dt));
746 }
747
748 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
750 return Ok(dt);
751 }
752
753 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
754 };
755
756 match value {
757 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
758 DataValue::InternedString(s) => parse_string(s.as_str()),
759 _ => Err(anyhow!("DATEADD requires date/datetime values")),
760 }
761 };
762
763 let base_date = parse_datetime(base_date_value)?;
765
766 let result_date = match unit.as_str() {
768 "day" | "days" => base_date + chrono::Duration::days(amount),
769 "month" | "months" => {
770 let mut year = base_date.year();
772 let mut month = base_date.month() as i32;
773 let day = base_date.day();
774
775 month += amount as i32;
776
777 while month > 12 {
779 month -= 12;
780 year += 1;
781 }
782 while month < 1 {
783 month += 12;
784 year -= 1;
785 }
786
787 let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
789 .unwrap_or_else(|| {
790 for test_day in (1..=day).rev() {
793 if let Some(date) =
794 NaiveDate::from_ymd_opt(year, month as u32, test_day)
795 {
796 return date;
797 }
798 }
799 NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
801 });
802
803 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
804 }
805 "year" | "years" => {
806 let new_year = base_date.year() + amount as i32;
807 let target_date =
808 NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
809 .unwrap_or_else(|| {
810 NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
812 .unwrap()
813 });
814 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
815 }
816 "hour" | "hours" => base_date + chrono::Duration::hours(amount),
817 "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
818 "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
819 _ => {
820 return Err(anyhow!(
821 "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
822 unit
823 ))
824 }
825 };
826
827 Ok(DataValue::DateTime(
829 result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
830 ))
831 }
832 "TEXTJOIN" => {
833 if args.len() < 3 {
834 return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
835 }
836
837 let delimiter = match self.evaluate(&args[0], row_index)? {
839 DataValue::String(s) => s,
840 DataValue::InternedString(s) => s.to_string(),
841 DataValue::Integer(n) => n.to_string(),
842 DataValue::Float(f) => f.to_string(),
843 DataValue::Boolean(b) => b.to_string(),
844 DataValue::Null => String::new(),
845 _ => String::new(),
846 };
847
848 let ignore_empty = match self.evaluate(&args[1], row_index)? {
850 DataValue::Integer(n) => n != 0,
851 DataValue::Float(f) => f != 0.0,
852 DataValue::Boolean(b) => b,
853 DataValue::String(s) => {
854 !s.is_empty() && s != "0" && s.to_lowercase() != "false"
855 }
856 DataValue::InternedString(s) => {
857 !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
858 }
859 DataValue::Null => false,
860 _ => true,
861 };
862
863 let mut values = Vec::new();
865 for i in 2..args.len() {
866 let value = self.evaluate(&args[i], row_index)?;
867 let string_value = match value {
868 DataValue::String(s) => Some(s),
869 DataValue::InternedString(s) => Some(s.to_string()),
870 DataValue::Integer(n) => Some(n.to_string()),
871 DataValue::Float(f) => Some(f.to_string()),
872 DataValue::Boolean(b) => Some(b.to_string()),
873 DataValue::DateTime(dt) => Some(dt),
874 DataValue::Null => {
875 if ignore_empty {
876 None
877 } else {
878 Some(String::new())
879 }
880 }
881 _ => {
882 if ignore_empty {
883 None
884 } else {
885 Some(String::new())
886 }
887 }
888 };
889
890 if let Some(s) = string_value {
891 if !ignore_empty || !s.is_empty() {
892 values.push(s);
893 }
894 }
895 }
896
897 Ok(DataValue::String(values.join(&delimiter)))
898 }
899 _ => Err(anyhow!("Unknown function: {}", name)),
900 }
901 }
902}
903
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// One-row fixture: `a` = Integer(10), `b` = Float(2.5), `c` = Integer(4).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for &name in &["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    /// Evaluates `expr` against row 0 of a fresh fixture table.
    fn eval_first_row(expr: &SqlExpression) -> DataValue {
        let table = create_test_table();
        ArithmeticEvaluator::new(&table).evaluate(expr, 0).unwrap()
    }

    #[test]
    fn test_evaluate_column() {
        let value = eval_first_row(&SqlExpression::Column("a".to_string()));
        assert_eq!(value, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal, expected) in cases.iter() {
            let value = eval_first_row(&SqlExpression::NumberLiteral(literal.to_string()));
            assert_eq!(&value, expected);
        }
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            (DataValue::Integer(5), DataValue::Integer(3), DataValue::Integer(8)),
            (DataValue::Integer(5), DataValue::Float(2.5), DataValue::Float(7.5)),
        ];
        for (lhs, rhs, expected) in cases.iter() {
            assert_eq!(&evaluator.add_values(lhs, rhs).unwrap(), expected);
        }
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            (DataValue::Integer(10), DataValue::Integer(2), DataValue::Integer(5)),
            (DataValue::Integer(10), DataValue::Integer(3), DataValue::Float(10.0 / 3.0)),
        ];
        for (numerator, denominator, expected) in cases.iter() {
            assert_eq!(
                &evaluator.divide_values(numerator, denominator).unwrap(),
                expected
            );
        }
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let product = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        assert_eq!(eval_first_row(&product), DataValue::Float(25.0));
    }
}