sql_cli/data/
arithmetic_evaluator.rs

1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
/// Evaluates SQL expressions to compute `DataValue`s (for SELECT clauses).
///
/// This is different from `RecursiveWhereEvaluator`, which returns a boolean.
/// The evaluator only borrows the table (for lifetime `'a`); it stores no
/// other state.
pub struct ArithmeticEvaluator<'a> {
    /// Borrowed table that column and row lookups are resolved against.
    table: &'a DataTable,
}
12
13impl<'a> ArithmeticEvaluator<'a> {
14    pub fn new(table: &'a DataTable) -> Self {
15        Self { table }
16    }
17
18    /// Find a column name similar to the given name using edit distance
19    fn find_similar_column(&self, name: &str) -> Option<String> {
20        let columns = self.table.column_names();
21        let mut best_match: Option<(String, usize)> = None;
22
23        for col in columns {
24            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25            // Only suggest if distance is small (likely a typo)
26            // Allow up to 3 edits for longer names
27            let max_distance = if name.len() > 10 { 3 } else { 2 };
28            if distance <= max_distance {
29                match &best_match {
30                    None => best_match = Some((col, distance)),
31                    Some((_, best_dist)) if distance < *best_dist => {
32                        best_match = Some((col, distance));
33                    }
34                    _ => {}
35                }
36            }
37        }
38
39        best_match.map(|(name, _)| name)
40    }
41
42    /// Calculate Levenshtein edit distance between two strings
43    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44        let len1 = s1.len();
45        let len2 = s2.len();
46        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48        for i in 0..=len1 {
49            matrix[i][0] = i;
50        }
51        for j in 0..=len2 {
52            matrix[0][j] = j;
53        }
54
55        for (i, c1) in s1.chars().enumerate() {
56            for (j, c2) in s2.chars().enumerate() {
57                let cost = if c1 == c2 { 0 } else { 1 };
58                matrix[i + 1][j + 1] = std::cmp::min(
59                    matrix[i][j + 1] + 1, // deletion
60                    std::cmp::min(
61                        matrix[i + 1][j] + 1, // insertion
62                        matrix[i][j] + cost,  // substitution
63                    ),
64                );
65            }
66        }
67
68        matrix[len1][len2]
69    }
70
71    /// Evaluate an SQL expression to produce a DataValue
72    pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73        debug!(
74            "ArithmeticEvaluator: evaluating {:?} for row {}",
75            expr, row_index
76        );
77
78        match expr {
79            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82            SqlExpression::BinaryOp { left, op, right } => {
83                self.evaluate_binary_op(left, op, right, row_index)
84            }
85            SqlExpression::FunctionCall { name, args } => {
86                self.evaluate_function(name, args, row_index)
87            }
88            _ => Err(anyhow!(
89                "Unsupported expression type for arithmetic evaluation: {:?}",
90                expr
91            )),
92        }
93    }
94
95    /// Evaluate a column reference
96    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
97        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
98            let suggestion = self.find_similar_column(column_name);
99            match suggestion {
100                Some(similar) => anyhow!(
101                    "Column '{}' not found. Did you mean '{}'?",
102                    column_name,
103                    similar
104                ),
105                None => anyhow!("Column '{}' not found", column_name),
106            }
107        })?;
108
109        if row_index >= self.table.row_count() {
110            return Err(anyhow!("Row index {} out of bounds", row_index));
111        }
112
113        let row = self
114            .table
115            .get_row(row_index)
116            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
117
118        let value = row
119            .get(col_index)
120            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
121
122        Ok(value.clone())
123    }
124
125    /// Evaluate a number literal (handles both integers and floats)
126    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
127        // Try to parse as integer first
128        if let Ok(int_val) = number_str.parse::<i64>() {
129            return Ok(DataValue::Integer(int_val));
130        }
131
132        // If that fails, try as float
133        if let Ok(float_val) = number_str.parse::<f64>() {
134            return Ok(DataValue::Float(float_val));
135        }
136
137        Err(anyhow!("Invalid number literal: {}", number_str))
138    }
139
140    /// Evaluate a binary operation (arithmetic)
141    fn evaluate_binary_op(
142        &self,
143        left: &SqlExpression,
144        op: &str,
145        right: &SqlExpression,
146        row_index: usize,
147    ) -> Result<DataValue> {
148        let left_val = self.evaluate(left, row_index)?;
149        let right_val = self.evaluate(right, row_index)?;
150
151        debug!(
152            "ArithmeticEvaluator: {} {} {}",
153            self.format_value(&left_val),
154            op,
155            self.format_value(&right_val)
156        );
157
158        match op {
159            "+" => self.add_values(&left_val, &right_val),
160            "-" => self.subtract_values(&left_val, &right_val),
161            "*" => self.multiply_values(&left_val, &right_val),
162            "/" => self.divide_values(&left_val, &right_val),
163            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
164        }
165    }
166
167    /// Add two DataValues with type coercion
168    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
169        match (left, right) {
170            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
171            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
172            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
173            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
174            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
175        }
176    }
177
178    /// Subtract two DataValues with type coercion
179    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
180        match (left, right) {
181            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
182            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
183            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
184            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
185            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
186        }
187    }
188
189    /// Multiply two DataValues with type coercion
190    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
191        match (left, right) {
192            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
193            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
194            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
195            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
196            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
197        }
198    }
199
200    /// Divide two DataValues with type coercion
201    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
202        // Check for division by zero first
203        let is_zero = match right {
204            DataValue::Integer(0) => true,
205            DataValue::Float(f) if f.abs() < f64::EPSILON => true,
206            _ => false,
207        };
208
209        if is_zero {
210            return Err(anyhow!("Division by zero"));
211        }
212
213        match (left, right) {
214            (DataValue::Integer(a), DataValue::Integer(b)) => {
215                // Integer division - if result is exact, keep as int, otherwise promote to float
216                if a % b == 0 {
217                    Ok(DataValue::Integer(a / b))
218                } else {
219                    Ok(DataValue::Float(*a as f64 / *b as f64))
220                }
221            }
222            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
223            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
224            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
225            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
226        }
227    }
228
229    /// Format a DataValue for debug output
230    fn format_value(&self, value: &DataValue) -> String {
231        match value {
232            DataValue::Integer(i) => i.to_string(),
233            DataValue::Float(f) => f.to_string(),
234            DataValue::String(s) => format!("'{}'", s),
235            _ => format!("{:?}", value),
236        }
237    }
238
239    /// Evaluate a function call
240    fn evaluate_function(
241        &self,
242        name: &str,
243        args: &[SqlExpression],
244        row_index: usize,
245    ) -> Result<DataValue> {
246        match name {
247            "ROUND" => {
248                if args.is_empty() || args.len() > 2 {
249                    return Err(anyhow!("ROUND requires 1 or 2 arguments"));
250                }
251
252                // Evaluate the value to round
253                let value = self.evaluate(&args[0], row_index)?;
254
255                // Get decimal places (default to 0 if not specified)
256                let decimals = if args.len() == 2 {
257                    match self.evaluate(&args[1], row_index)? {
258                        DataValue::Integer(n) => n as i32,
259                        DataValue::Float(f) => f as i32,
260                        _ => return Err(anyhow!("ROUND precision must be a number")),
261                    }
262                } else {
263                    0
264                };
265
266                // Perform rounding
267                match value {
268                    DataValue::Integer(n) => Ok(DataValue::Integer(n)), // Already an integer
269                    DataValue::Float(f) => {
270                        if decimals >= 0 {
271                            let multiplier = 10_f64.powi(decimals);
272                            let rounded = (f * multiplier).round() / multiplier;
273                            if decimals == 0 {
274                                // Return as integer if rounding to 0 decimals
275                                Ok(DataValue::Integer(rounded as i64))
276                            } else {
277                                Ok(DataValue::Float(rounded))
278                            }
279                        } else {
280                            // Negative decimals round to left of decimal point
281                            let divisor = 10_f64.powi(-decimals);
282                            let rounded = (f / divisor).round() * divisor;
283                            Ok(DataValue::Float(rounded))
284                        }
285                    }
286                    _ => Err(anyhow!("ROUND can only be applied to numeric values")),
287                }
288            }
289            "ABS" => {
290                if args.len() != 1 {
291                    return Err(anyhow!("ABS requires exactly 1 argument"));
292                }
293
294                let value = self.evaluate(&args[0], row_index)?;
295                match value {
296                    DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
297                    DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
298                    _ => Err(anyhow!("ABS can only be applied to numeric values")),
299                }
300            }
301            "FLOOR" => {
302                if args.len() != 1 {
303                    return Err(anyhow!("FLOOR requires exactly 1 argument"));
304                }
305
306                let value = self.evaluate(&args[0], row_index)?;
307                match value {
308                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
309                    DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
310                    _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
311                }
312            }
313            "CEILING" | "CEIL" => {
314                if args.len() != 1 {
315                    return Err(anyhow!("CEILING requires exactly 1 argument"));
316                }
317
318                let value = self.evaluate(&args[0], row_index)?;
319                match value {
320                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
321                    DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
322                    _ => Err(anyhow!("CEILING can only be applied to numeric values")),
323                }
324            }
325            "MOD" => {
326                if args.len() != 2 {
327                    return Err(anyhow!("MOD requires exactly 2 arguments"));
328                }
329
330                let dividend = self.evaluate(&args[0], row_index)?;
331                let divisor = self.evaluate(&args[1], row_index)?;
332
333                match (&dividend, &divisor) {
334                    (DataValue::Integer(n), DataValue::Integer(d)) => {
335                        if *d == 0 {
336                            return Err(anyhow!("Division by zero in MOD"));
337                        }
338                        Ok(DataValue::Integer(n % d))
339                    }
340                    _ => {
341                        // Convert to float for mixed types
342                        let n = match dividend {
343                            DataValue::Integer(i) => i as f64,
344                            DataValue::Float(f) => f,
345                            _ => return Err(anyhow!("MOD requires numeric arguments")),
346                        };
347                        let d = match divisor {
348                            DataValue::Integer(i) => i as f64,
349                            DataValue::Float(f) => f,
350                            _ => return Err(anyhow!("MOD requires numeric arguments")),
351                        };
352                        if d == 0.0 {
353                            return Err(anyhow!("Division by zero in MOD"));
354                        }
355                        Ok(DataValue::Float(n % d))
356                    }
357                }
358            }
359            "QUOTIENT" => {
360                if args.len() != 2 {
361                    return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
362                }
363
364                let numerator = self.evaluate(&args[0], row_index)?;
365                let denominator = self.evaluate(&args[1], row_index)?;
366
367                match (&numerator, &denominator) {
368                    (DataValue::Integer(n), DataValue::Integer(d)) => {
369                        if *d == 0 {
370                            return Err(anyhow!("Division by zero in QUOTIENT"));
371                        }
372                        Ok(DataValue::Integer(n / d))
373                    }
374                    _ => {
375                        // Convert to float for mixed types
376                        let n = match numerator {
377                            DataValue::Integer(i) => i as f64,
378                            DataValue::Float(f) => f,
379                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
380                        };
381                        let d = match denominator {
382                            DataValue::Integer(i) => i as f64,
383                            DataValue::Float(f) => f,
384                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
385                        };
386                        if d == 0.0 {
387                            return Err(anyhow!("Division by zero in QUOTIENT"));
388                        }
389                        Ok(DataValue::Integer((n / d).trunc() as i64))
390                    }
391                }
392            }
393            "POWER" | "POW" => {
394                if args.len() != 2 {
395                    return Err(anyhow!("POWER requires exactly 2 arguments"));
396                }
397
398                let base = self.evaluate(&args[0], row_index)?;
399                let exponent = self.evaluate(&args[1], row_index)?;
400
401                match (&base, &exponent) {
402                    (DataValue::Integer(b), DataValue::Integer(e)) => {
403                        if *e >= 0 && *e <= i32::MAX as i64 {
404                            Ok(DataValue::Float((*b as f64).powi(*e as i32)))
405                        } else {
406                            Ok(DataValue::Float((*b as f64).powf(*e as f64)))
407                        }
408                    }
409                    _ => {
410                        // Convert to float for mixed types or floats
411                        let b = match base {
412                            DataValue::Integer(i) => i as f64,
413                            DataValue::Float(f) => f,
414                            _ => return Err(anyhow!("POWER requires numeric arguments")),
415                        };
416                        let e = match exponent {
417                            DataValue::Integer(i) => i as f64,
418                            DataValue::Float(f) => f,
419                            _ => return Err(anyhow!("POWER requires numeric arguments")),
420                        };
421                        Ok(DataValue::Float(b.powf(e)))
422                    }
423                }
424            }
425            "SQRT" => {
426                if args.len() != 1 {
427                    return Err(anyhow!("SQRT requires exactly 1 argument"));
428                }
429
430                let value = self.evaluate(&args[0], row_index)?;
431                match value {
432                    DataValue::Integer(n) => {
433                        if n < 0 {
434                            return Err(anyhow!("SQRT of negative number"));
435                        }
436                        Ok(DataValue::Float((n as f64).sqrt()))
437                    }
438                    DataValue::Float(f) => {
439                        if f < 0.0 {
440                            return Err(anyhow!("SQRT of negative number"));
441                        }
442                        Ok(DataValue::Float(f.sqrt()))
443                    }
444                    _ => Err(anyhow!("SQRT can only be applied to numeric values")),
445                }
446            }
447            "EXP" => {
448                if args.len() != 1 {
449                    return Err(anyhow!("EXP requires exactly 1 argument"));
450                }
451
452                let value = self.evaluate(&args[0], row_index)?;
453                match value {
454                    DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
455                    DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
456                    _ => Err(anyhow!("EXP can only be applied to numeric values")),
457                }
458            }
459            "LN" => {
460                if args.len() != 1 {
461                    return Err(anyhow!("LN requires exactly 1 argument"));
462                }
463
464                let value = self.evaluate(&args[0], row_index)?;
465                match value {
466                    DataValue::Integer(n) => {
467                        if n <= 0 {
468                            return Err(anyhow!("LN of non-positive number"));
469                        }
470                        Ok(DataValue::Float((n as f64).ln()))
471                    }
472                    DataValue::Float(f) => {
473                        if f <= 0.0 {
474                            return Err(anyhow!("LN of non-positive number"));
475                        }
476                        Ok(DataValue::Float(f.ln()))
477                    }
478                    _ => Err(anyhow!("LN can only be applied to numeric values")),
479                }
480            }
481            "LOG" | "LOG10" => {
482                if name == "LOG" && args.len() == 2 {
483                    // LOG with custom base
484                    let value = self.evaluate(&args[0], row_index)?;
485                    let base = self.evaluate(&args[1], row_index)?;
486
487                    let n = match value {
488                        DataValue::Integer(i) => i as f64,
489                        DataValue::Float(f) => f,
490                        _ => return Err(anyhow!("LOG requires numeric arguments")),
491                    };
492                    let b = match base {
493                        DataValue::Integer(i) => i as f64,
494                        DataValue::Float(f) => f,
495                        _ => return Err(anyhow!("LOG requires numeric arguments")),
496                    };
497
498                    if n <= 0.0 {
499                        return Err(anyhow!("LOG of non-positive number"));
500                    }
501                    if b <= 0.0 || b == 1.0 {
502                        return Err(anyhow!("Invalid LOG base"));
503                    }
504                    Ok(DataValue::Float(n.log(b)))
505                } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
506                    // LOG10 or LOG with default base 10
507                    if args.len() != 1 {
508                        return Err(anyhow!("{} requires exactly 1 argument", name));
509                    }
510
511                    let value = self.evaluate(&args[0], row_index)?;
512                    match value {
513                        DataValue::Integer(n) => {
514                            if n <= 0 {
515                                return Err(anyhow!("LOG10 of non-positive number"));
516                            }
517                            Ok(DataValue::Float((n as f64).log10()))
518                        }
519                        DataValue::Float(f) => {
520                            if f <= 0.0 {
521                                return Err(anyhow!("LOG10 of non-positive number"));
522                            }
523                            Ok(DataValue::Float(f.log10()))
524                        }
525                        _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
526                    }
527                } else {
528                    Err(anyhow!("LOG requires 1 or 2 arguments"))
529                }
530            }
531            "PI" => {
532                if !args.is_empty() {
533                    return Err(anyhow!("PI takes no arguments"));
534                }
535                Ok(DataValue::Float(std::f64::consts::PI))
536            }
537            "DATEDIFF" => {
538                if args.len() != 3 {
539                    return Err(anyhow!(
540                        "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
541                    ));
542                }
543
544                // First argument: unit (day, month, year, hour, minute, second)
545                let unit = match self.evaluate(&args[0], row_index)? {
546                    DataValue::String(s) => s.to_lowercase(),
547                    DataValue::InternedString(s) => s.to_lowercase(),
548                    _ => return Err(anyhow!("DATEDIFF unit must be a string")),
549                };
550
551                // Helper function to parse date/datetime strings
552                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
553                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
554                        // Try various date/datetime formats
555
556                        // ISO formats (most common)
557                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
558                            return Ok(Utc.from_utc_datetime(&dt));
559                        }
560                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
561                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
562                        }
563
564                        // US format: MM/DD/YYYY or MM-DD-YYYY
565                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
566                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
567                        }
568                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
569                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
570                        }
571
572                        // European format: DD/MM/YYYY or DD-MM-YYYY
573                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
574                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
575                        }
576                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
577                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
578                        }
579
580                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
581                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
582                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
583                        }
584
585                        // Full month names: January 15, 2024 or 15 January 2024
586                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
587                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
588                        }
589                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
590                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
591                        }
592
593                        // With time: MM/DD/YYYY HH:MM:SS
594                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
595                            return Ok(Utc.from_utc_datetime(&dt));
596                        }
597                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
598                            return Ok(Utc.from_utc_datetime(&dt));
599                        }
600
601                        // ISO 8601 / RFC3339
602                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
603                            return Ok(dt);
604                        }
605
606                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
607                    };
608
609                    match value {
610                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
611                        DataValue::InternedString(s) => parse_string(s.as_str()),
612                        _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
613                    }
614                };
615
616                // Parse both dates
617                let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
618                let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
619
620                // Calculate difference based on unit
621                let diff = match unit.as_str() {
622                    "day" | "days" => {
623                        let duration = date2.signed_duration_since(date1);
624                        duration.num_days()
625                    }
626                    "month" | "months" => {
627                        // Approximate months as 30.44 days
628                        let duration = date2.signed_duration_since(date1);
629                        duration.num_days() / 30
630                    }
631                    "year" | "years" => {
632                        // Approximate years as 365.25 days
633                        let duration = date2.signed_duration_since(date1);
634                        duration.num_days() / 365
635                    }
636                    "hour" | "hours" => {
637                        let duration = date2.signed_duration_since(date1);
638                        duration.num_hours()
639                    }
640                    "minute" | "minutes" => {
641                        let duration = date2.signed_duration_since(date1);
642                        duration.num_minutes()
643                    }
644                    "second" | "seconds" => {
645                        let duration = date2.signed_duration_since(date1);
646                        duration.num_seconds()
647                    }
648                    _ => {
649                        return Err(anyhow!(
650                        "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
651                        unit
652                    ))
653                    }
654                };
655
656                Ok(DataValue::Integer(diff))
657            }
658            "NOW" => {
659                if !args.is_empty() {
660                    return Err(anyhow!("NOW takes no arguments"));
661                }
662                let now = Utc::now();
663                Ok(DataValue::DateTime(
664                    now.format("%Y-%m-%d %H:%M:%S").to_string(),
665                ))
666            }
667            "TODAY" => {
668                if !args.is_empty() {
669                    return Err(anyhow!("TODAY takes no arguments"));
670                }
671                let today = Utc::now().date_naive();
672                Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
673            }
674            "DATEADD" => {
675                if args.len() != 3 {
676                    return Err(anyhow!(
677                        "DATEADD requires exactly 3 arguments: unit, number, date"
678                    ));
679                }
680
681                // First argument: unit (day, month, year, hour, minute, second)
682                let unit = match self.evaluate(&args[0], row_index)? {
683                    DataValue::String(s) => s.to_lowercase(),
684                    DataValue::InternedString(s) => s.to_lowercase(),
685                    _ => return Err(anyhow!("DATEADD unit must be a string")),
686                };
687
688                // Second argument: number to add (can be negative for subtraction)
689                let amount = match self.evaluate(&args[1], row_index)? {
690                    DataValue::Integer(i) => i,
691                    DataValue::Float(f) => f as i64,
692                    _ => return Err(anyhow!("DATEADD amount must be a number")),
693                };
694
695                // Third argument: base date
696                let base_date_value = self.evaluate(&args[2], row_index)?;
697
698                // Reuse the parse_datetime function from DATEDIFF
699                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
700                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
701                        // Try various date/datetime formats (same as DATEDIFF)
702
703                        // ISO formats (most common)
704                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
705                            return Ok(Utc.from_utc_datetime(&dt));
706                        }
707                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
708                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
709                        }
710
711                        // US format: MM/DD/YYYY or MM-DD-YYYY
712                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
713                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
714                        }
715                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
716                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
717                        }
718
719                        // European format: DD/MM/YYYY or DD-MM-YYYY
720                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
721                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
722                        }
723                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
724                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
725                        }
726
727                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
728                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
729                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
730                        }
731
732                        // Full month names: January 15, 2024 or 15 January 2024
733                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
734                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
735                        }
736                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
737                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
738                        }
739
740                        // With time: MM/DD/YYYY HH:MM:SS
741                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
742                            return Ok(Utc.from_utc_datetime(&dt));
743                        }
744                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
745                            return Ok(Utc.from_utc_datetime(&dt));
746                        }
747
748                        // ISO 8601 / RFC3339
749                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
750                            return Ok(dt);
751                        }
752
753                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
754                    };
755
756                    match value {
757                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
758                        DataValue::InternedString(s) => parse_string(s.as_str()),
759                        _ => Err(anyhow!("DATEADD requires date/datetime values")),
760                    }
761                };
762
763                // Parse the base date
764                let base_date = parse_datetime(base_date_value)?;
765
766                // Add the specified amount based on unit
767                let result_date = match unit.as_str() {
768                    "day" | "days" => base_date + chrono::Duration::days(amount),
769                    "month" | "months" => {
770                        // For months, we need to be careful about month boundaries
771                        let mut year = base_date.year();
772                        let mut month = base_date.month() as i32;
773                        let day = base_date.day();
774
775                        month += amount as i32;
776
777                        // Handle month overflow/underflow
778                        while month > 12 {
779                            month -= 12;
780                            year += 1;
781                        }
782                        while month < 1 {
783                            month += 12;
784                            year -= 1;
785                        }
786
787                        // Create new date, handling day overflow (e.g., Jan 31 + 1 month = Feb 28/29)
788                        let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
789                            .unwrap_or_else(|| {
790                                // If day doesn't exist in target month, use the last day of that month
791                                // Try decreasing days until we find a valid one
792                                for test_day in (1..=day).rev() {
793                                    if let Some(date) =
794                                        NaiveDate::from_ymd_opt(year, month as u32, test_day)
795                                    {
796                                        return date;
797                                    }
798                                }
799                                // This should never happen, but fallback to day 28 as safety
800                                NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
801                            });
802
803                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
804                    }
805                    "year" | "years" => {
806                        let new_year = base_date.year() + amount as i32;
807                        let target_date =
808                            NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
809                                .unwrap_or_else(|| {
810                                    // Handle Feb 29 in non-leap years
811                                    NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
812                                        .unwrap()
813                                });
814                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
815                    }
816                    "hour" | "hours" => base_date + chrono::Duration::hours(amount),
817                    "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
818                    "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
819                    _ => {
820                        return Err(anyhow!(
821                            "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
822                            unit
823                        ))
824                    }
825                };
826
827                // Return as datetime string
828                Ok(DataValue::DateTime(
829                    result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
830                ))
831            }
832            "TEXTJOIN" => {
833                if args.len() < 3 {
834                    return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
835                }
836
837                // First argument: delimiter
838                let delimiter = match self.evaluate(&args[0], row_index)? {
839                    DataValue::String(s) => s,
840                    DataValue::InternedString(s) => s.to_string(),
841                    DataValue::Integer(n) => n.to_string(),
842                    DataValue::Float(f) => f.to_string(),
843                    DataValue::Boolean(b) => b.to_string(),
844                    DataValue::Null => String::new(),
845                    _ => String::new(),
846                };
847
848                // Second argument: ignore_empty (treat as boolean - 0 is false, anything else is true)
849                let ignore_empty = match self.evaluate(&args[1], row_index)? {
850                    DataValue::Integer(n) => n != 0,
851                    DataValue::Float(f) => f != 0.0,
852                    DataValue::Boolean(b) => b,
853                    DataValue::String(s) => {
854                        !s.is_empty() && s != "0" && s.to_lowercase() != "false"
855                    }
856                    DataValue::InternedString(s) => {
857                        !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
858                    }
859                    DataValue::Null => false,
860                    _ => true,
861                };
862
863                // Remaining arguments: values to join
864                let mut values = Vec::new();
865                for i in 2..args.len() {
866                    let value = self.evaluate(&args[i], row_index)?;
867                    let string_value = match value {
868                        DataValue::String(s) => Some(s),
869                        DataValue::InternedString(s) => Some(s.to_string()),
870                        DataValue::Integer(n) => Some(n.to_string()),
871                        DataValue::Float(f) => Some(f.to_string()),
872                        DataValue::Boolean(b) => Some(b.to_string()),
873                        DataValue::DateTime(dt) => Some(dt),
874                        DataValue::Null => {
875                            if ignore_empty {
876                                None
877                            } else {
878                                Some(String::new())
879                            }
880                        }
881                        _ => {
882                            if ignore_empty {
883                                None
884                            } else {
885                                Some(String::new())
886                            }
887                        }
888                    };
889
890                    if let Some(s) = string_value {
891                        if !ignore_empty || !s.is_empty() {
892                            values.push(s);
893                        }
894                    }
895                }
896
897                Ok(DataValue::String(values.join(&delimiter)))
898            }
899            _ => Err(anyhow!("Unknown function: {}", name)),
900        }
901    }
902}
903
#[cfg(test)]
mod tests {
    //! Unit tests for `ArithmeticEvaluator`: column and literal evaluation,
    //! the `add_values` / `multiply_values` / `divide_values` helpers, and
    //! evaluation of a binary-operator expression.

    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds the shared fixture: a table named `test` with columns `a`, `b`, `c`
    /// and a single row holding `Integer(10)`, `Float(2.5)`, `Integer(4)`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for column_name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(column_name));
        }

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// A column reference evaluates to the value stored in that column for the row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    /// The literal `"42"` evaluates to `Integer(42)` and `"2.75"` to `Float(2.75)`.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is used instead of 3.14: a float literal that approximates pi trips
        // clippy's deny-by-default `approx_constant` lint. 2.75 is also exactly
        // representable in binary floating point, so the equality check is exact.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    /// Integer + Integer stays an integer; Integer + Float yields a float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    /// Integer * Float yields a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    /// Integer division returns an integer when exact and a float otherwise.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error mentioning "Division by zero".
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(0))
            .expect_err("dividing by zero must return an error");
        assert!(err.to_string().contains("Division by zero"));
    }

    /// A `BinaryOp` expression evaluates both operands against the row and combines them.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}