1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
/// Evaluates SQL expressions (column references, literals, binary arithmetic,
/// function calls and method calls) against individual rows of a [`DataTable`].
pub struct ArithmeticEvaluator<'a> {
    /// Table whose columns and rows are read during evaluation.
    table: &'a DataTable,
}
12
13impl<'a> ArithmeticEvaluator<'a> {
14 pub fn new(table: &'a DataTable) -> Self {
15 Self { table }
16 }
17
18 fn find_similar_column(&self, name: &str) -> Option<String> {
20 let columns = self.table.column_names();
21 let mut best_match: Option<(String, usize)> = None;
22
23 for col in columns {
24 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25 let max_distance = if name.len() > 10 { 3 } else { 2 };
28 if distance <= max_distance {
29 match &best_match {
30 None => best_match = Some((col, distance)),
31 Some((_, best_dist)) if distance < *best_dist => {
32 best_match = Some((col, distance));
33 }
34 _ => {}
35 }
36 }
37 }
38
39 best_match.map(|(name, _)| name)
40 }
41
42 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44 let len1 = s1.len();
45 let len2 = s2.len();
46 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48 for i in 0..=len1 {
49 matrix[i][0] = i;
50 }
51 for j in 0..=len2 {
52 matrix[0][j] = j;
53 }
54
55 for (i, c1) in s1.chars().enumerate() {
56 for (j, c2) in s2.chars().enumerate() {
57 let cost = if c1 == c2 { 0 } else { 1 };
58 matrix[i + 1][j + 1] = std::cmp::min(
59 matrix[i][j + 1] + 1, std::cmp::min(
61 matrix[i + 1][j] + 1, matrix[i][j] + cost, ),
64 );
65 }
66 }
67
68 matrix[len1][len2]
69 }
70
71 pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73 debug!(
74 "ArithmeticEvaluator: evaluating {:?} for row {}",
75 expr, row_index
76 );
77
78 match expr {
79 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82 SqlExpression::BinaryOp { left, op, right } => {
83 self.evaluate_binary_op(left, op, right, row_index)
84 }
85 SqlExpression::FunctionCall { name, args } => {
86 self.evaluate_function(name, args, row_index)
87 }
88 SqlExpression::MethodCall {
89 object,
90 method,
91 args,
92 } => self.evaluate_method_call(object, method, args, row_index),
93 SqlExpression::ChainedMethodCall { base, method, args } => {
94 let base_value = self.evaluate(base, row_index)?;
96 self.evaluate_method_on_value(&base_value, method, args, row_index)
97 }
98 _ => Err(anyhow!(
99 "Unsupported expression type for arithmetic evaluation: {:?}",
100 expr
101 )),
102 }
103 }
104
105 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
107 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
108 let suggestion = self.find_similar_column(column_name);
109 match suggestion {
110 Some(similar) => anyhow!(
111 "Column '{}' not found. Did you mean '{}'?",
112 column_name,
113 similar
114 ),
115 None => anyhow!("Column '{}' not found", column_name),
116 }
117 })?;
118
119 if row_index >= self.table.row_count() {
120 return Err(anyhow!("Row index {} out of bounds", row_index));
121 }
122
123 let row = self
124 .table
125 .get_row(row_index)
126 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
127
128 let value = row
129 .get(col_index)
130 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
131
132 Ok(value.clone())
133 }
134
135 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
137 if let Ok(int_val) = number_str.parse::<i64>() {
139 return Ok(DataValue::Integer(int_val));
140 }
141
142 if let Ok(float_val) = number_str.parse::<f64>() {
144 return Ok(DataValue::Float(float_val));
145 }
146
147 Err(anyhow!("Invalid number literal: {}", number_str))
148 }
149
150 fn evaluate_binary_op(
152 &self,
153 left: &SqlExpression,
154 op: &str,
155 right: &SqlExpression,
156 row_index: usize,
157 ) -> Result<DataValue> {
158 let left_val = self.evaluate(left, row_index)?;
159 let right_val = self.evaluate(right, row_index)?;
160
161 debug!(
162 "ArithmeticEvaluator: {} {} {}",
163 self.format_value(&left_val),
164 op,
165 self.format_value(&right_val)
166 );
167
168 match op {
169 "+" => self.add_values(&left_val, &right_val),
170 "-" => self.subtract_values(&left_val, &right_val),
171 "*" => self.multiply_values(&left_val, &right_val),
172 "/" => self.divide_values(&left_val, &right_val),
173 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
174 }
175 }
176
177 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
179 match (left, right) {
180 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
181 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
182 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
183 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
184 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
185 }
186 }
187
188 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
190 match (left, right) {
191 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
192 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
193 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
194 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
195 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
196 }
197 }
198
199 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
201 match (left, right) {
202 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
203 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
204 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
205 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
206 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
207 }
208 }
209
210 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
212 let is_zero = match right {
214 DataValue::Integer(0) => true,
215 DataValue::Float(f) if f.abs() < f64::EPSILON => true,
216 _ => false,
217 };
218
219 if is_zero {
220 return Err(anyhow!("Division by zero"));
221 }
222
223 match (left, right) {
224 (DataValue::Integer(a), DataValue::Integer(b)) => {
225 if a % b == 0 {
227 Ok(DataValue::Integer(a / b))
228 } else {
229 Ok(DataValue::Float(*a as f64 / *b as f64))
230 }
231 }
232 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
233 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
234 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
235 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
236 }
237 }
238
239 fn format_value(&self, value: &DataValue) -> String {
241 match value {
242 DataValue::Integer(i) => i.to_string(),
243 DataValue::Float(f) => f.to_string(),
244 DataValue::String(s) => format!("'{}'", s),
245 _ => format!("{:?}", value),
246 }
247 }
248
249 fn evaluate_function(
251 &self,
252 name: &str,
253 args: &[SqlExpression],
254 row_index: usize,
255 ) -> Result<DataValue> {
256 match name {
257 "ROUND" => {
258 if args.is_empty() || args.len() > 2 {
259 return Err(anyhow!("ROUND requires 1 or 2 arguments"));
260 }
261
262 let value = self.evaluate(&args[0], row_index)?;
264
265 let decimals = if args.len() == 2 {
267 match self.evaluate(&args[1], row_index)? {
268 DataValue::Integer(n) => n as i32,
269 DataValue::Float(f) => f as i32,
270 _ => return Err(anyhow!("ROUND precision must be a number")),
271 }
272 } else {
273 0
274 };
275
276 match value {
278 DataValue::Integer(n) => Ok(DataValue::Integer(n)), DataValue::Float(f) => {
280 if decimals >= 0 {
281 let multiplier = 10_f64.powi(decimals);
282 let rounded = (f * multiplier).round() / multiplier;
283 if decimals == 0 {
284 Ok(DataValue::Integer(rounded as i64))
286 } else {
287 Ok(DataValue::Float(rounded))
288 }
289 } else {
290 let divisor = 10_f64.powi(-decimals);
292 let rounded = (f / divisor).round() * divisor;
293 Ok(DataValue::Float(rounded))
294 }
295 }
296 _ => Err(anyhow!("ROUND can only be applied to numeric values")),
297 }
298 }
299 "ABS" => {
300 if args.len() != 1 {
301 return Err(anyhow!("ABS requires exactly 1 argument"));
302 }
303
304 let value = self.evaluate(&args[0], row_index)?;
305 match value {
306 DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
307 DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
308 _ => Err(anyhow!("ABS can only be applied to numeric values")),
309 }
310 }
311 "FLOOR" => {
312 if args.len() != 1 {
313 return Err(anyhow!("FLOOR requires exactly 1 argument"));
314 }
315
316 let value = self.evaluate(&args[0], row_index)?;
317 match value {
318 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
319 DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
320 _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
321 }
322 }
323 "CEILING" | "CEIL" => {
324 if args.len() != 1 {
325 return Err(anyhow!("CEILING requires exactly 1 argument"));
326 }
327
328 let value = self.evaluate(&args[0], row_index)?;
329 match value {
330 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
331 DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
332 _ => Err(anyhow!("CEILING can only be applied to numeric values")),
333 }
334 }
335 "MOD" => {
336 if args.len() != 2 {
337 return Err(anyhow!("MOD requires exactly 2 arguments"));
338 }
339
340 let dividend = self.evaluate(&args[0], row_index)?;
341 let divisor = self.evaluate(&args[1], row_index)?;
342
343 match (÷nd, &divisor) {
344 (DataValue::Integer(n), DataValue::Integer(d)) => {
345 if *d == 0 {
346 return Err(anyhow!("Division by zero in MOD"));
347 }
348 Ok(DataValue::Integer(n % d))
349 }
350 _ => {
351 let n = match dividend {
353 DataValue::Integer(i) => i as f64,
354 DataValue::Float(f) => f,
355 _ => return Err(anyhow!("MOD requires numeric arguments")),
356 };
357 let d = match divisor {
358 DataValue::Integer(i) => i as f64,
359 DataValue::Float(f) => f,
360 _ => return Err(anyhow!("MOD requires numeric arguments")),
361 };
362 if d == 0.0 {
363 return Err(anyhow!("Division by zero in MOD"));
364 }
365 Ok(DataValue::Float(n % d))
366 }
367 }
368 }
369 "QUOTIENT" => {
370 if args.len() != 2 {
371 return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
372 }
373
374 let numerator = self.evaluate(&args[0], row_index)?;
375 let denominator = self.evaluate(&args[1], row_index)?;
376
377 match (&numerator, &denominator) {
378 (DataValue::Integer(n), DataValue::Integer(d)) => {
379 if *d == 0 {
380 return Err(anyhow!("Division by zero in QUOTIENT"));
381 }
382 Ok(DataValue::Integer(n / d))
383 }
384 _ => {
385 let n = match numerator {
387 DataValue::Integer(i) => i as f64,
388 DataValue::Float(f) => f,
389 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
390 };
391 let d = match denominator {
392 DataValue::Integer(i) => i as f64,
393 DataValue::Float(f) => f,
394 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
395 };
396 if d == 0.0 {
397 return Err(anyhow!("Division by zero in QUOTIENT"));
398 }
399 Ok(DataValue::Integer((n / d).trunc() as i64))
400 }
401 }
402 }
403 "POWER" | "POW" => {
404 if args.len() != 2 {
405 return Err(anyhow!("POWER requires exactly 2 arguments"));
406 }
407
408 let base = self.evaluate(&args[0], row_index)?;
409 let exponent = self.evaluate(&args[1], row_index)?;
410
411 match (&base, &exponent) {
412 (DataValue::Integer(b), DataValue::Integer(e)) => {
413 if *e >= 0 && *e <= i32::MAX as i64 {
414 Ok(DataValue::Float((*b as f64).powi(*e as i32)))
415 } else {
416 Ok(DataValue::Float((*b as f64).powf(*e as f64)))
417 }
418 }
419 _ => {
420 let b = match base {
422 DataValue::Integer(i) => i as f64,
423 DataValue::Float(f) => f,
424 _ => return Err(anyhow!("POWER requires numeric arguments")),
425 };
426 let e = match exponent {
427 DataValue::Integer(i) => i as f64,
428 DataValue::Float(f) => f,
429 _ => return Err(anyhow!("POWER requires numeric arguments")),
430 };
431 Ok(DataValue::Float(b.powf(e)))
432 }
433 }
434 }
435 "SQRT" => {
436 if args.len() != 1 {
437 return Err(anyhow!("SQRT requires exactly 1 argument"));
438 }
439
440 let value = self.evaluate(&args[0], row_index)?;
441 match value {
442 DataValue::Integer(n) => {
443 if n < 0 {
444 return Err(anyhow!("SQRT of negative number"));
445 }
446 Ok(DataValue::Float((n as f64).sqrt()))
447 }
448 DataValue::Float(f) => {
449 if f < 0.0 {
450 return Err(anyhow!("SQRT of negative number"));
451 }
452 Ok(DataValue::Float(f.sqrt()))
453 }
454 _ => Err(anyhow!("SQRT can only be applied to numeric values")),
455 }
456 }
457 "EXP" => {
458 if args.len() != 1 {
459 return Err(anyhow!("EXP requires exactly 1 argument"));
460 }
461
462 let value = self.evaluate(&args[0], row_index)?;
463 match value {
464 DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
465 DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
466 _ => Err(anyhow!("EXP can only be applied to numeric values")),
467 }
468 }
469 "LN" => {
470 if args.len() != 1 {
471 return Err(anyhow!("LN requires exactly 1 argument"));
472 }
473
474 let value = self.evaluate(&args[0], row_index)?;
475 match value {
476 DataValue::Integer(n) => {
477 if n <= 0 {
478 return Err(anyhow!("LN of non-positive number"));
479 }
480 Ok(DataValue::Float((n as f64).ln()))
481 }
482 DataValue::Float(f) => {
483 if f <= 0.0 {
484 return Err(anyhow!("LN of non-positive number"));
485 }
486 Ok(DataValue::Float(f.ln()))
487 }
488 _ => Err(anyhow!("LN can only be applied to numeric values")),
489 }
490 }
491 "LOG" | "LOG10" => {
492 if name == "LOG" && args.len() == 2 {
493 let value = self.evaluate(&args[0], row_index)?;
495 let base = self.evaluate(&args[1], row_index)?;
496
497 let n = match value {
498 DataValue::Integer(i) => i as f64,
499 DataValue::Float(f) => f,
500 _ => return Err(anyhow!("LOG requires numeric arguments")),
501 };
502 let b = match base {
503 DataValue::Integer(i) => i as f64,
504 DataValue::Float(f) => f,
505 _ => return Err(anyhow!("LOG requires numeric arguments")),
506 };
507
508 if n <= 0.0 {
509 return Err(anyhow!("LOG of non-positive number"));
510 }
511 if b <= 0.0 || b == 1.0 {
512 return Err(anyhow!("Invalid LOG base"));
513 }
514 Ok(DataValue::Float(n.log(b)))
515 } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
516 if args.len() != 1 {
518 return Err(anyhow!("{} requires exactly 1 argument", name));
519 }
520
521 let value = self.evaluate(&args[0], row_index)?;
522 match value {
523 DataValue::Integer(n) => {
524 if n <= 0 {
525 return Err(anyhow!("LOG10 of non-positive number"));
526 }
527 Ok(DataValue::Float((n as f64).log10()))
528 }
529 DataValue::Float(f) => {
530 if f <= 0.0 {
531 return Err(anyhow!("LOG10 of non-positive number"));
532 }
533 Ok(DataValue::Float(f.log10()))
534 }
535 _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
536 }
537 } else {
538 Err(anyhow!("LOG requires 1 or 2 arguments"))
539 }
540 }
541 "PI" => {
542 if !args.is_empty() {
543 return Err(anyhow!("PI takes no arguments"));
544 }
545 Ok(DataValue::Float(std::f64::consts::PI))
546 }
547 "DATEDIFF" => {
548 if args.len() != 3 {
549 return Err(anyhow!(
550 "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
551 ));
552 }
553
554 let unit = match self.evaluate(&args[0], row_index)? {
556 DataValue::String(s) => s.to_lowercase(),
557 DataValue::InternedString(s) => s.to_lowercase(),
558 _ => return Err(anyhow!("DATEDIFF unit must be a string")),
559 };
560
561 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
563 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
564 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
568 return Ok(Utc.from_utc_datetime(&dt));
569 }
570 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
571 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
572 }
573
574 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
576 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
577 }
578 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
579 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
580 }
581
582 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
584 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
585 }
586 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
587 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
588 }
589
590 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
592 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
593 }
594
595 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
597 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
598 }
599 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
600 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
601 }
602
603 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
605 return Ok(Utc.from_utc_datetime(&dt));
606 }
607 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
608 return Ok(Utc.from_utc_datetime(&dt));
609 }
610
611 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
613 return Ok(dt);
614 }
615
616 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
617 };
618
619 match value {
620 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
621 DataValue::InternedString(s) => parse_string(s.as_str()),
622 _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
623 }
624 };
625
626 let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
628 let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
629
630 let diff = match unit.as_str() {
632 "day" | "days" => {
633 let duration = date2.signed_duration_since(date1);
634 duration.num_days()
635 }
636 "month" | "months" => {
637 let duration = date2.signed_duration_since(date1);
639 duration.num_days() / 30
640 }
641 "year" | "years" => {
642 let duration = date2.signed_duration_since(date1);
644 duration.num_days() / 365
645 }
646 "hour" | "hours" => {
647 let duration = date2.signed_duration_since(date1);
648 duration.num_hours()
649 }
650 "minute" | "minutes" => {
651 let duration = date2.signed_duration_since(date1);
652 duration.num_minutes()
653 }
654 "second" | "seconds" => {
655 let duration = date2.signed_duration_since(date1);
656 duration.num_seconds()
657 }
658 _ => {
659 return Err(anyhow!(
660 "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
661 unit
662 ))
663 }
664 };
665
666 Ok(DataValue::Integer(diff))
667 }
668 "NOW" => {
669 if !args.is_empty() {
670 return Err(anyhow!("NOW takes no arguments"));
671 }
672 let now = Utc::now();
673 Ok(DataValue::DateTime(
674 now.format("%Y-%m-%d %H:%M:%S").to_string(),
675 ))
676 }
677 "TODAY" => {
678 if !args.is_empty() {
679 return Err(anyhow!("TODAY takes no arguments"));
680 }
681 let today = Utc::now().date_naive();
682 Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
683 }
684 "DATEADD" => {
685 if args.len() != 3 {
686 return Err(anyhow!(
687 "DATEADD requires exactly 3 arguments: unit, number, date"
688 ));
689 }
690
691 let unit = match self.evaluate(&args[0], row_index)? {
693 DataValue::String(s) => s.to_lowercase(),
694 DataValue::InternedString(s) => s.to_lowercase(),
695 _ => return Err(anyhow!("DATEADD unit must be a string")),
696 };
697
698 let amount = match self.evaluate(&args[1], row_index)? {
700 DataValue::Integer(i) => i,
701 DataValue::Float(f) => f as i64,
702 _ => return Err(anyhow!("DATEADD amount must be a number")),
703 };
704
705 let base_date_value = self.evaluate(&args[2], row_index)?;
707
708 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
710 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
711 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
715 return Ok(Utc.from_utc_datetime(&dt));
716 }
717 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
718 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
719 }
720
721 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
723 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
724 }
725 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
726 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
727 }
728
729 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
731 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
732 }
733 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
734 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
735 }
736
737 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
739 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
740 }
741
742 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
744 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
745 }
746 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
747 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
748 }
749
750 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
752 return Ok(Utc.from_utc_datetime(&dt));
753 }
754 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
755 return Ok(Utc.from_utc_datetime(&dt));
756 }
757
758 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
760 return Ok(dt);
761 }
762
763 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
764 };
765
766 match value {
767 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
768 DataValue::InternedString(s) => parse_string(s.as_str()),
769 _ => Err(anyhow!("DATEADD requires date/datetime values")),
770 }
771 };
772
773 let base_date = parse_datetime(base_date_value)?;
775
776 let result_date = match unit.as_str() {
778 "day" | "days" => base_date + chrono::Duration::days(amount),
779 "month" | "months" => {
780 let mut year = base_date.year();
782 let mut month = base_date.month() as i32;
783 let day = base_date.day();
784
785 month += amount as i32;
786
787 while month > 12 {
789 month -= 12;
790 year += 1;
791 }
792 while month < 1 {
793 month += 12;
794 year -= 1;
795 }
796
797 let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
799 .unwrap_or_else(|| {
800 for test_day in (1..=day).rev() {
803 if let Some(date) =
804 NaiveDate::from_ymd_opt(year, month as u32, test_day)
805 {
806 return date;
807 }
808 }
809 NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
811 });
812
813 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
814 }
815 "year" | "years" => {
816 let new_year = base_date.year() + amount as i32;
817 let target_date =
818 NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
819 .unwrap_or_else(|| {
820 NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
822 .unwrap()
823 });
824 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
825 }
826 "hour" | "hours" => base_date + chrono::Duration::hours(amount),
827 "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
828 "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
829 _ => {
830 return Err(anyhow!(
831 "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
832 unit
833 ))
834 }
835 };
836
837 Ok(DataValue::DateTime(
839 result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
840 ))
841 }
842 "TEXTJOIN" => {
843 if args.len() < 3 {
844 return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
845 }
846
847 let delimiter = match self.evaluate(&args[0], row_index)? {
849 DataValue::String(s) => s,
850 DataValue::InternedString(s) => s.to_string(),
851 DataValue::Integer(n) => n.to_string(),
852 DataValue::Float(f) => f.to_string(),
853 DataValue::Boolean(b) => b.to_string(),
854 DataValue::Null => String::new(),
855 _ => String::new(),
856 };
857
858 let ignore_empty = match self.evaluate(&args[1], row_index)? {
860 DataValue::Integer(n) => n != 0,
861 DataValue::Float(f) => f != 0.0,
862 DataValue::Boolean(b) => b,
863 DataValue::String(s) => {
864 !s.is_empty() && s != "0" && s.to_lowercase() != "false"
865 }
866 DataValue::InternedString(s) => {
867 !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
868 }
869 DataValue::Null => false,
870 _ => true,
871 };
872
873 let mut values = Vec::new();
875 for i in 2..args.len() {
876 let value = self.evaluate(&args[i], row_index)?;
877 let string_value = match value {
878 DataValue::String(s) => Some(s),
879 DataValue::InternedString(s) => Some(s.to_string()),
880 DataValue::Integer(n) => Some(n.to_string()),
881 DataValue::Float(f) => Some(f.to_string()),
882 DataValue::Boolean(b) => Some(b.to_string()),
883 DataValue::DateTime(dt) => Some(dt),
884 DataValue::Null => {
885 if ignore_empty {
886 None
887 } else {
888 Some(String::new())
889 }
890 }
891 _ => {
892 if ignore_empty {
893 None
894 } else {
895 Some(String::new())
896 }
897 }
898 };
899
900 if let Some(s) = string_value {
901 if !ignore_empty || !s.is_empty() {
902 values.push(s);
903 }
904 }
905 }
906
907 Ok(DataValue::String(values.join(&delimiter)))
908 }
909 _ => Err(anyhow!("Unknown function: {}", name)),
910 }
911 }
912
913 fn evaluate_method_call(
915 &self,
916 object: &str,
917 method: &str,
918 args: &[SqlExpression],
919 row_index: usize,
920 ) -> Result<DataValue> {
921 let col_index = self.table.get_column_index(object).ok_or_else(|| {
923 let suggestion = self.find_similar_column(object);
924 match suggestion {
925 Some(similar) => {
926 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
927 }
928 None => anyhow!("Column '{}' not found", object),
929 }
930 })?;
931
932 let cell_value = self.table.get_value(row_index, col_index).cloned();
933
934 self.evaluate_method_on_value(
935 &cell_value.unwrap_or(DataValue::Null),
936 method,
937 args,
938 row_index,
939 )
940 }
941
942 fn evaluate_method_on_value(
944 &self,
945 value: &DataValue,
946 method: &str,
947 args: &[SqlExpression],
948 row_index: usize,
949 ) -> Result<DataValue> {
950 match method.to_lowercase().as_str() {
951 "trim" | "trimstart" | "trimend" => {
952 if !args.is_empty() {
953 return Err(anyhow!("{} takes no arguments", method));
954 }
955
956 let str_val = match value {
958 DataValue::String(s) => s.clone(),
959 DataValue::InternedString(s) => s.to_string(),
960 DataValue::Integer(n) => n.to_string(),
961 DataValue::Float(f) => f.to_string(),
962 DataValue::Boolean(b) => b.to_string(),
963 DataValue::DateTime(dt) => dt.clone(),
964 DataValue::Null => return Ok(DataValue::Null),
965 };
966
967 let result = match method.to_lowercase().as_str() {
968 "trim" => str_val.trim().to_string(),
969 "trimstart" => str_val.trim_start().to_string(),
970 "trimend" => str_val.trim_end().to_string(),
971 _ => unreachable!(),
972 };
973
974 Ok(DataValue::String(result))
975 }
976 "length" => {
977 if !args.is_empty() {
978 return Err(anyhow!("Length takes no arguments"));
979 }
980
981 let len = match value {
983 DataValue::String(s) => s.len(),
984 DataValue::InternedString(s) => s.len(),
985 DataValue::Integer(n) => n.to_string().len(),
986 DataValue::Float(f) => f.to_string().len(),
987 DataValue::Boolean(b) => b.to_string().len(),
988 DataValue::DateTime(dt) => dt.len(),
989 DataValue::Null => return Ok(DataValue::Integer(0)),
990 };
991
992 Ok(DataValue::Integer(len as i64))
993 }
994 "indexof" => {
995 if args.len() != 1 {
996 return Err(anyhow!("IndexOf requires exactly 1 argument"));
997 }
998
999 let search_str = match self.evaluate(&args[0], row_index)? {
1001 DataValue::String(s) => s,
1002 DataValue::InternedString(s) => s.to_string(),
1003 DataValue::Integer(n) => n.to_string(),
1004 DataValue::Float(f) => f.to_string(),
1005 _ => return Err(anyhow!("IndexOf argument must be a string")),
1006 };
1007
1008 let str_val = match value {
1010 DataValue::String(s) => s.clone(),
1011 DataValue::InternedString(s) => s.to_string(),
1012 DataValue::Integer(n) => n.to_string(),
1013 DataValue::Float(f) => f.to_string(),
1014 DataValue::Boolean(b) => b.to_string(),
1015 DataValue::DateTime(dt) => dt.clone(),
1016 DataValue::Null => return Ok(DataValue::Integer(-1)),
1017 };
1018
1019 let index = str_val.find(&search_str).map(|i| i as i64).unwrap_or(-1);
1020
1021 Ok(DataValue::Integer(index))
1022 }
1023 "contains" => {
1024 if args.len() != 1 {
1025 return Err(anyhow!("Contains requires exactly 1 argument"));
1026 }
1027
1028 let search_str = match self.evaluate(&args[0], row_index)? {
1030 DataValue::String(s) => s,
1031 DataValue::InternedString(s) => s.to_string(),
1032 DataValue::Integer(n) => n.to_string(),
1033 DataValue::Float(f) => f.to_string(),
1034 _ => return Err(anyhow!("Contains argument must be a string")),
1035 };
1036
1037 let str_val = match value {
1039 DataValue::String(s) => s.clone(),
1040 DataValue::InternedString(s) => s.to_string(),
1041 DataValue::Integer(n) => n.to_string(),
1042 DataValue::Float(f) => f.to_string(),
1043 DataValue::Boolean(b) => b.to_string(),
1044 DataValue::DateTime(dt) => dt.clone(),
1045 DataValue::Null => return Ok(DataValue::Boolean(false)),
1046 };
1047
1048 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1050 Ok(DataValue::Boolean(result))
1051 }
1052 "startswith" => {
1053 if args.len() != 1 {
1054 return Err(anyhow!("StartsWith requires exactly 1 argument"));
1055 }
1056
1057 let prefix = match self.evaluate(&args[0], row_index)? {
1059 DataValue::String(s) => s,
1060 DataValue::InternedString(s) => s.to_string(),
1061 DataValue::Integer(n) => n.to_string(),
1062 DataValue::Float(f) => f.to_string(),
1063 _ => return Err(anyhow!("StartsWith argument must be a string")),
1064 };
1065
1066 let str_val = match value {
1068 DataValue::String(s) => s.clone(),
1069 DataValue::InternedString(s) => s.to_string(),
1070 DataValue::Integer(n) => n.to_string(),
1071 DataValue::Float(f) => f.to_string(),
1072 DataValue::Boolean(b) => b.to_string(),
1073 DataValue::DateTime(dt) => dt.clone(),
1074 DataValue::Null => return Ok(DataValue::Boolean(false)),
1075 };
1076
1077 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1079 Ok(DataValue::Boolean(result))
1080 }
1081 "endswith" => {
1082 if args.len() != 1 {
1083 return Err(anyhow!("EndsWith requires exactly 1 argument"));
1084 }
1085
1086 let suffix = match self.evaluate(&args[0], row_index)? {
1088 DataValue::String(s) => s,
1089 DataValue::InternedString(s) => s.to_string(),
1090 DataValue::Integer(n) => n.to_string(),
1091 DataValue::Float(f) => f.to_string(),
1092 _ => return Err(anyhow!("EndsWith argument must be a string")),
1093 };
1094
1095 let str_val = match value {
1097 DataValue::String(s) => s.clone(),
1098 DataValue::InternedString(s) => s.to_string(),
1099 DataValue::Integer(n) => n.to_string(),
1100 DataValue::Float(f) => f.to_string(),
1101 DataValue::Boolean(b) => b.to_string(),
1102 DataValue::DateTime(dt) => dt.clone(),
1103 DataValue::Null => return Ok(DataValue::Boolean(false)),
1104 };
1105
1106 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1108 Ok(DataValue::Boolean(result))
1109 }
1110 _ => Err(anyhow!("Unsupported method: {}", method)),
1111 }
1112 }
1113}
1114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns a = 10, b = 2.5 and c = 4.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    /// A column reference yields the cell stored in that column.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column("a".to_string());
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    /// Whole-number literals become Integers, decimal literals become Floats.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let integer_literal = SqlExpression::NumberLiteral("42".to_string());
        assert_eq!(
            evaluator.evaluate(&integer_literal, 0).unwrap(),
            DataValue::Integer(42)
        );

        let float_literal = SqlExpression::NumberLiteral("3.14".to_string());
        assert_eq!(
            evaluator.evaluate(&float_literal, 0).unwrap(),
            DataValue::Float(3.14)
        );
    }

    /// Integer + Integer stays Integer; mixing in a Float yields a Float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let integer_sum = evaluator.add_values(&DataValue::Integer(5), &DataValue::Integer(3));
        assert_eq!(integer_sum.unwrap(), DataValue::Integer(8));

        let mixed_sum = evaluator.add_values(&DataValue::Integer(5), &DataValue::Float(2.5));
        assert_eq!(mixed_sum.unwrap(), DataValue::Float(7.5));
    }

    /// Integer * Float yields a Float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator.multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5));
        assert_eq!(product.unwrap(), DataValue::Float(10.0));
    }

    /// Exact integer division stays Integer; inexact division yields a Float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let exact = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(2));
        assert_eq!(exact.unwrap(), DataValue::Integer(5));

        let inexact = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(3));
        assert_eq!(inexact.unwrap(), DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by the Integer 0 is reported as an error.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .to_string()
            .contains("Division by zero"));
    }

    /// `a * b` on the test row is 10 * 2.5 = 25.0.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let product_of_columns = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        assert_eq!(
            evaluator.evaluate(&product_of_columns, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}
1233 }
1234}