sql_cli/data/
arithmetic_evaluator.rs

1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
/// Evaluates SQL expressions to compute DataValues (for SELECT clauses)
/// This is different from RecursiveWhereEvaluator which returns boolean
pub struct ArithmeticEvaluator<'a> {
    /// Table that column references are resolved against during evaluation.
    table: &'a DataTable,
}
12
13impl<'a> ArithmeticEvaluator<'a> {
14    pub fn new(table: &'a DataTable) -> Self {
15        Self { table }
16    }
17
18    /// Find a column name similar to the given name using edit distance
19    fn find_similar_column(&self, name: &str) -> Option<String> {
20        let columns = self.table.column_names();
21        let mut best_match: Option<(String, usize)> = None;
22
23        for col in columns {
24            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25            // Only suggest if distance is small (likely a typo)
26            // Allow up to 3 edits for longer names
27            let max_distance = if name.len() > 10 { 3 } else { 2 };
28            if distance <= max_distance {
29                match &best_match {
30                    None => best_match = Some((col, distance)),
31                    Some((_, best_dist)) if distance < *best_dist => {
32                        best_match = Some((col, distance));
33                    }
34                    _ => {}
35                }
36            }
37        }
38
39        best_match.map(|(name, _)| name)
40    }
41
42    /// Calculate Levenshtein edit distance between two strings
43    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44        let len1 = s1.len();
45        let len2 = s2.len();
46        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48        for i in 0..=len1 {
49            matrix[i][0] = i;
50        }
51        for j in 0..=len2 {
52            matrix[0][j] = j;
53        }
54
55        for (i, c1) in s1.chars().enumerate() {
56            for (j, c2) in s2.chars().enumerate() {
57                let cost = if c1 == c2 { 0 } else { 1 };
58                matrix[i + 1][j + 1] = std::cmp::min(
59                    matrix[i][j + 1] + 1, // deletion
60                    std::cmp::min(
61                        matrix[i + 1][j] + 1, // insertion
62                        matrix[i][j] + cost,  // substitution
63                    ),
64                );
65            }
66        }
67
68        matrix[len1][len2]
69    }
70
71    /// Evaluate an SQL expression to produce a DataValue
72    pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73        debug!(
74            "ArithmeticEvaluator: evaluating {:?} for row {}",
75            expr, row_index
76        );
77
78        match expr {
79            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82            SqlExpression::BinaryOp { left, op, right } => {
83                self.evaluate_binary_op(left, op, right, row_index)
84            }
85            SqlExpression::FunctionCall { name, args } => {
86                self.evaluate_function(name, args, row_index)
87            }
88            SqlExpression::MethodCall {
89                object,
90                method,
91                args,
92            } => self.evaluate_method_call(object, method, args, row_index),
93            SqlExpression::ChainedMethodCall { base, method, args } => {
94                // Evaluate the base expression first, then apply the method
95                let base_value = self.evaluate(base, row_index)?;
96                self.evaluate_method_on_value(&base_value, method, args, row_index)
97            }
98            _ => Err(anyhow!(
99                "Unsupported expression type for arithmetic evaluation: {:?}",
100                expr
101            )),
102        }
103    }
104
105    /// Evaluate a column reference
106    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
107        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
108            let suggestion = self.find_similar_column(column_name);
109            match suggestion {
110                Some(similar) => anyhow!(
111                    "Column '{}' not found. Did you mean '{}'?",
112                    column_name,
113                    similar
114                ),
115                None => anyhow!("Column '{}' not found", column_name),
116            }
117        })?;
118
119        if row_index >= self.table.row_count() {
120            return Err(anyhow!("Row index {} out of bounds", row_index));
121        }
122
123        let row = self
124            .table
125            .get_row(row_index)
126            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
127
128        let value = row
129            .get(col_index)
130            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
131
132        Ok(value.clone())
133    }
134
135    /// Evaluate a number literal (handles both integers and floats)
136    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
137        // Try to parse as integer first
138        if let Ok(int_val) = number_str.parse::<i64>() {
139            return Ok(DataValue::Integer(int_val));
140        }
141
142        // If that fails, try as float
143        if let Ok(float_val) = number_str.parse::<f64>() {
144            return Ok(DataValue::Float(float_val));
145        }
146
147        Err(anyhow!("Invalid number literal: {}", number_str))
148    }
149
150    /// Evaluate a binary operation (arithmetic)
151    fn evaluate_binary_op(
152        &self,
153        left: &SqlExpression,
154        op: &str,
155        right: &SqlExpression,
156        row_index: usize,
157    ) -> Result<DataValue> {
158        let left_val = self.evaluate(left, row_index)?;
159        let right_val = self.evaluate(right, row_index)?;
160
161        debug!(
162            "ArithmeticEvaluator: {} {} {}",
163            self.format_value(&left_val),
164            op,
165            self.format_value(&right_val)
166        );
167
168        match op {
169            "+" => self.add_values(&left_val, &right_val),
170            "-" => self.subtract_values(&left_val, &right_val),
171            "*" => self.multiply_values(&left_val, &right_val),
172            "/" => self.divide_values(&left_val, &right_val),
173            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
174        }
175    }
176
177    /// Add two DataValues with type coercion
178    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
179        match (left, right) {
180            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
181            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
182            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
183            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
184            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
185        }
186    }
187
188    /// Subtract two DataValues with type coercion
189    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
190        match (left, right) {
191            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
192            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
193            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
194            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
195            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
196        }
197    }
198
199    /// Multiply two DataValues with type coercion
200    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
201        match (left, right) {
202            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
203            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
204            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
205            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
206            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
207        }
208    }
209
210    /// Divide two DataValues with type coercion
211    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
212        // Check for division by zero first
213        let is_zero = match right {
214            DataValue::Integer(0) => true,
215            DataValue::Float(f) if f.abs() < f64::EPSILON => true,
216            _ => false,
217        };
218
219        if is_zero {
220            return Err(anyhow!("Division by zero"));
221        }
222
223        match (left, right) {
224            (DataValue::Integer(a), DataValue::Integer(b)) => {
225                // Integer division - if result is exact, keep as int, otherwise promote to float
226                if a % b == 0 {
227                    Ok(DataValue::Integer(a / b))
228                } else {
229                    Ok(DataValue::Float(*a as f64 / *b as f64))
230                }
231            }
232            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
233            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
234            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
235            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
236        }
237    }
238
239    /// Format a DataValue for debug output
240    fn format_value(&self, value: &DataValue) -> String {
241        match value {
242            DataValue::Integer(i) => i.to_string(),
243            DataValue::Float(f) => f.to_string(),
244            DataValue::String(s) => format!("'{}'", s),
245            _ => format!("{:?}", value),
246        }
247    }
248
249    /// Evaluate a function call
250    fn evaluate_function(
251        &self,
252        name: &str,
253        args: &[SqlExpression],
254        row_index: usize,
255    ) -> Result<DataValue> {
256        match name {
257            "ROUND" => {
258                if args.is_empty() || args.len() > 2 {
259                    return Err(anyhow!("ROUND requires 1 or 2 arguments"));
260                }
261
262                // Evaluate the value to round
263                let value = self.evaluate(&args[0], row_index)?;
264
265                // Get decimal places (default to 0 if not specified)
266                let decimals = if args.len() == 2 {
267                    match self.evaluate(&args[1], row_index)? {
268                        DataValue::Integer(n) => n as i32,
269                        DataValue::Float(f) => f as i32,
270                        _ => return Err(anyhow!("ROUND precision must be a number")),
271                    }
272                } else {
273                    0
274                };
275
276                // Perform rounding
277                match value {
278                    DataValue::Integer(n) => Ok(DataValue::Integer(n)), // Already an integer
279                    DataValue::Float(f) => {
280                        if decimals >= 0 {
281                            let multiplier = 10_f64.powi(decimals);
282                            let rounded = (f * multiplier).round() / multiplier;
283                            if decimals == 0 {
284                                // Return as integer if rounding to 0 decimals
285                                Ok(DataValue::Integer(rounded as i64))
286                            } else {
287                                Ok(DataValue::Float(rounded))
288                            }
289                        } else {
290                            // Negative decimals round to left of decimal point
291                            let divisor = 10_f64.powi(-decimals);
292                            let rounded = (f / divisor).round() * divisor;
293                            Ok(DataValue::Float(rounded))
294                        }
295                    }
296                    _ => Err(anyhow!("ROUND can only be applied to numeric values")),
297                }
298            }
299            "ABS" => {
300                if args.len() != 1 {
301                    return Err(anyhow!("ABS requires exactly 1 argument"));
302                }
303
304                let value = self.evaluate(&args[0], row_index)?;
305                match value {
306                    DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
307                    DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
308                    _ => Err(anyhow!("ABS can only be applied to numeric values")),
309                }
310            }
311            "FLOOR" => {
312                if args.len() != 1 {
313                    return Err(anyhow!("FLOOR requires exactly 1 argument"));
314                }
315
316                let value = self.evaluate(&args[0], row_index)?;
317                match value {
318                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
319                    DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
320                    _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
321                }
322            }
323            "CEILING" | "CEIL" => {
324                if args.len() != 1 {
325                    return Err(anyhow!("CEILING requires exactly 1 argument"));
326                }
327
328                let value = self.evaluate(&args[0], row_index)?;
329                match value {
330                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
331                    DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
332                    _ => Err(anyhow!("CEILING can only be applied to numeric values")),
333                }
334            }
335            "MOD" => {
336                if args.len() != 2 {
337                    return Err(anyhow!("MOD requires exactly 2 arguments"));
338                }
339
340                let dividend = self.evaluate(&args[0], row_index)?;
341                let divisor = self.evaluate(&args[1], row_index)?;
342
343                match (&dividend, &divisor) {
344                    (DataValue::Integer(n), DataValue::Integer(d)) => {
345                        if *d == 0 {
346                            return Err(anyhow!("Division by zero in MOD"));
347                        }
348                        Ok(DataValue::Integer(n % d))
349                    }
350                    _ => {
351                        // Convert to float for mixed types
352                        let n = match dividend {
353                            DataValue::Integer(i) => i as f64,
354                            DataValue::Float(f) => f,
355                            _ => return Err(anyhow!("MOD requires numeric arguments")),
356                        };
357                        let d = match divisor {
358                            DataValue::Integer(i) => i as f64,
359                            DataValue::Float(f) => f,
360                            _ => return Err(anyhow!("MOD requires numeric arguments")),
361                        };
362                        if d == 0.0 {
363                            return Err(anyhow!("Division by zero in MOD"));
364                        }
365                        Ok(DataValue::Float(n % d))
366                    }
367                }
368            }
369            "QUOTIENT" => {
370                if args.len() != 2 {
371                    return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
372                }
373
374                let numerator = self.evaluate(&args[0], row_index)?;
375                let denominator = self.evaluate(&args[1], row_index)?;
376
377                match (&numerator, &denominator) {
378                    (DataValue::Integer(n), DataValue::Integer(d)) => {
379                        if *d == 0 {
380                            return Err(anyhow!("Division by zero in QUOTIENT"));
381                        }
382                        Ok(DataValue::Integer(n / d))
383                    }
384                    _ => {
385                        // Convert to float for mixed types
386                        let n = match numerator {
387                            DataValue::Integer(i) => i as f64,
388                            DataValue::Float(f) => f,
389                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
390                        };
391                        let d = match denominator {
392                            DataValue::Integer(i) => i as f64,
393                            DataValue::Float(f) => f,
394                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
395                        };
396                        if d == 0.0 {
397                            return Err(anyhow!("Division by zero in QUOTIENT"));
398                        }
399                        Ok(DataValue::Integer((n / d).trunc() as i64))
400                    }
401                }
402            }
403            "POWER" | "POW" => {
404                if args.len() != 2 {
405                    return Err(anyhow!("POWER requires exactly 2 arguments"));
406                }
407
408                let base = self.evaluate(&args[0], row_index)?;
409                let exponent = self.evaluate(&args[1], row_index)?;
410
411                match (&base, &exponent) {
412                    (DataValue::Integer(b), DataValue::Integer(e)) => {
413                        if *e >= 0 && *e <= i32::MAX as i64 {
414                            Ok(DataValue::Float((*b as f64).powi(*e as i32)))
415                        } else {
416                            Ok(DataValue::Float((*b as f64).powf(*e as f64)))
417                        }
418                    }
419                    _ => {
420                        // Convert to float for mixed types or floats
421                        let b = match base {
422                            DataValue::Integer(i) => i as f64,
423                            DataValue::Float(f) => f,
424                            _ => return Err(anyhow!("POWER requires numeric arguments")),
425                        };
426                        let e = match exponent {
427                            DataValue::Integer(i) => i as f64,
428                            DataValue::Float(f) => f,
429                            _ => return Err(anyhow!("POWER requires numeric arguments")),
430                        };
431                        Ok(DataValue::Float(b.powf(e)))
432                    }
433                }
434            }
435            "SQRT" => {
436                if args.len() != 1 {
437                    return Err(anyhow!("SQRT requires exactly 1 argument"));
438                }
439
440                let value = self.evaluate(&args[0], row_index)?;
441                match value {
442                    DataValue::Integer(n) => {
443                        if n < 0 {
444                            return Err(anyhow!("SQRT of negative number"));
445                        }
446                        Ok(DataValue::Float((n as f64).sqrt()))
447                    }
448                    DataValue::Float(f) => {
449                        if f < 0.0 {
450                            return Err(anyhow!("SQRT of negative number"));
451                        }
452                        Ok(DataValue::Float(f.sqrt()))
453                    }
454                    _ => Err(anyhow!("SQRT can only be applied to numeric values")),
455                }
456            }
457            "EXP" => {
458                if args.len() != 1 {
459                    return Err(anyhow!("EXP requires exactly 1 argument"));
460                }
461
462                let value = self.evaluate(&args[0], row_index)?;
463                match value {
464                    DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
465                    DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
466                    _ => Err(anyhow!("EXP can only be applied to numeric values")),
467                }
468            }
469            "LN" => {
470                if args.len() != 1 {
471                    return Err(anyhow!("LN requires exactly 1 argument"));
472                }
473
474                let value = self.evaluate(&args[0], row_index)?;
475                match value {
476                    DataValue::Integer(n) => {
477                        if n <= 0 {
478                            return Err(anyhow!("LN of non-positive number"));
479                        }
480                        Ok(DataValue::Float((n as f64).ln()))
481                    }
482                    DataValue::Float(f) => {
483                        if f <= 0.0 {
484                            return Err(anyhow!("LN of non-positive number"));
485                        }
486                        Ok(DataValue::Float(f.ln()))
487                    }
488                    _ => Err(anyhow!("LN can only be applied to numeric values")),
489                }
490            }
491            "LOG" | "LOG10" => {
492                if name == "LOG" && args.len() == 2 {
493                    // LOG with custom base
494                    let value = self.evaluate(&args[0], row_index)?;
495                    let base = self.evaluate(&args[1], row_index)?;
496
497                    let n = match value {
498                        DataValue::Integer(i) => i as f64,
499                        DataValue::Float(f) => f,
500                        _ => return Err(anyhow!("LOG requires numeric arguments")),
501                    };
502                    let b = match base {
503                        DataValue::Integer(i) => i as f64,
504                        DataValue::Float(f) => f,
505                        _ => return Err(anyhow!("LOG requires numeric arguments")),
506                    };
507
508                    if n <= 0.0 {
509                        return Err(anyhow!("LOG of non-positive number"));
510                    }
511                    if b <= 0.0 || b == 1.0 {
512                        return Err(anyhow!("Invalid LOG base"));
513                    }
514                    Ok(DataValue::Float(n.log(b)))
515                } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
516                    // LOG10 or LOG with default base 10
517                    if args.len() != 1 {
518                        return Err(anyhow!("{} requires exactly 1 argument", name));
519                    }
520
521                    let value = self.evaluate(&args[0], row_index)?;
522                    match value {
523                        DataValue::Integer(n) => {
524                            if n <= 0 {
525                                return Err(anyhow!("LOG10 of non-positive number"));
526                            }
527                            Ok(DataValue::Float((n as f64).log10()))
528                        }
529                        DataValue::Float(f) => {
530                            if f <= 0.0 {
531                                return Err(anyhow!("LOG10 of non-positive number"));
532                            }
533                            Ok(DataValue::Float(f.log10()))
534                        }
535                        _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
536                    }
537                } else {
538                    Err(anyhow!("LOG requires 1 or 2 arguments"))
539                }
540            }
541            "PI" => {
542                if !args.is_empty() {
543                    return Err(anyhow!("PI takes no arguments"));
544                }
545                Ok(DataValue::Float(std::f64::consts::PI))
546            }
547            "DATEDIFF" => {
548                if args.len() != 3 {
549                    return Err(anyhow!(
550                        "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
551                    ));
552                }
553
554                // First argument: unit (day, month, year, hour, minute, second)
555                let unit = match self.evaluate(&args[0], row_index)? {
556                    DataValue::String(s) => s.to_lowercase(),
557                    DataValue::InternedString(s) => s.to_lowercase(),
558                    _ => return Err(anyhow!("DATEDIFF unit must be a string")),
559                };
560
561                // Helper function to parse date/datetime strings
562                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
563                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
564                        // Try various date/datetime formats
565
566                        // ISO formats (most common)
567                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
568                            return Ok(Utc.from_utc_datetime(&dt));
569                        }
570                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
571                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
572                        }
573
574                        // US format: MM/DD/YYYY or MM-DD-YYYY
575                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
576                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
577                        }
578                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
579                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
580                        }
581
582                        // European format: DD/MM/YYYY or DD-MM-YYYY
583                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
584                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
585                        }
586                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
587                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
588                        }
589
590                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
591                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
592                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
593                        }
594
595                        // Full month names: January 15, 2024 or 15 January 2024
596                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
597                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
598                        }
599                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
600                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
601                        }
602
603                        // With time: MM/DD/YYYY HH:MM:SS
604                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
605                            return Ok(Utc.from_utc_datetime(&dt));
606                        }
607                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
608                            return Ok(Utc.from_utc_datetime(&dt));
609                        }
610
611                        // ISO 8601 / RFC3339
612                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
613                            return Ok(dt);
614                        }
615
616                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
617                    };
618
619                    match value {
620                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
621                        DataValue::InternedString(s) => parse_string(s.as_str()),
622                        _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
623                    }
624                };
625
626                // Parse both dates
627                let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
628                let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
629
630                // Calculate difference based on unit
631                let diff = match unit.as_str() {
632                    "day" | "days" => {
633                        let duration = date2.signed_duration_since(date1);
634                        duration.num_days()
635                    }
636                    "month" | "months" => {
637                        // Approximate months as 30.44 days
638                        let duration = date2.signed_duration_since(date1);
639                        duration.num_days() / 30
640                    }
641                    "year" | "years" => {
642                        // Approximate years as 365.25 days
643                        let duration = date2.signed_duration_since(date1);
644                        duration.num_days() / 365
645                    }
646                    "hour" | "hours" => {
647                        let duration = date2.signed_duration_since(date1);
648                        duration.num_hours()
649                    }
650                    "minute" | "minutes" => {
651                        let duration = date2.signed_duration_since(date1);
652                        duration.num_minutes()
653                    }
654                    "second" | "seconds" => {
655                        let duration = date2.signed_duration_since(date1);
656                        duration.num_seconds()
657                    }
658                    _ => {
659                        return Err(anyhow!(
660                        "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
661                        unit
662                    ))
663                    }
664                };
665
666                Ok(DataValue::Integer(diff))
667            }
668            "NOW" => {
669                if !args.is_empty() {
670                    return Err(anyhow!("NOW takes no arguments"));
671                }
672                let now = Utc::now();
673                Ok(DataValue::DateTime(
674                    now.format("%Y-%m-%d %H:%M:%S").to_string(),
675                ))
676            }
677            "TODAY" => {
678                if !args.is_empty() {
679                    return Err(anyhow!("TODAY takes no arguments"));
680                }
681                let today = Utc::now().date_naive();
682                Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
683            }
684            "DATEADD" => {
685                if args.len() != 3 {
686                    return Err(anyhow!(
687                        "DATEADD requires exactly 3 arguments: unit, number, date"
688                    ));
689                }
690
691                // First argument: unit (day, month, year, hour, minute, second)
692                let unit = match self.evaluate(&args[0], row_index)? {
693                    DataValue::String(s) => s.to_lowercase(),
694                    DataValue::InternedString(s) => s.to_lowercase(),
695                    _ => return Err(anyhow!("DATEADD unit must be a string")),
696                };
697
698                // Second argument: number to add (can be negative for subtraction)
699                let amount = match self.evaluate(&args[1], row_index)? {
700                    DataValue::Integer(i) => i,
701                    DataValue::Float(f) => f as i64,
702                    _ => return Err(anyhow!("DATEADD amount must be a number")),
703                };
704
705                // Third argument: base date
706                let base_date_value = self.evaluate(&args[2], row_index)?;
707
708                // Reuse the parse_datetime function from DATEDIFF
709                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
710                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
711                        // Try various date/datetime formats (same as DATEDIFF)
712
713                        // ISO formats (most common)
714                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
715                            return Ok(Utc.from_utc_datetime(&dt));
716                        }
717                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
718                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
719                        }
720
721                        // US format: MM/DD/YYYY or MM-DD-YYYY
722                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
723                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
724                        }
725                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
726                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
727                        }
728
729                        // European format: DD/MM/YYYY or DD-MM-YYYY
730                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
731                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
732                        }
733                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
734                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
735                        }
736
737                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
738                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
739                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
740                        }
741
742                        // Full month names: January 15, 2024 or 15 January 2024
743                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
744                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
745                        }
746                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
747                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
748                        }
749
750                        // With time: MM/DD/YYYY HH:MM:SS
751                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
752                            return Ok(Utc.from_utc_datetime(&dt));
753                        }
754                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
755                            return Ok(Utc.from_utc_datetime(&dt));
756                        }
757
758                        // ISO 8601 / RFC3339
759                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
760                            return Ok(dt);
761                        }
762
763                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
764                    };
765
766                    match value {
767                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
768                        DataValue::InternedString(s) => parse_string(s.as_str()),
769                        _ => Err(anyhow!("DATEADD requires date/datetime values")),
770                    }
771                };
772
773                // Parse the base date
774                let base_date = parse_datetime(base_date_value)?;
775
776                // Add the specified amount based on unit
777                let result_date = match unit.as_str() {
778                    "day" | "days" => base_date + chrono::Duration::days(amount),
779                    "month" | "months" => {
780                        // For months, we need to be careful about month boundaries
781                        let mut year = base_date.year();
782                        let mut month = base_date.month() as i32;
783                        let day = base_date.day();
784
785                        month += amount as i32;
786
787                        // Handle month overflow/underflow
788                        while month > 12 {
789                            month -= 12;
790                            year += 1;
791                        }
792                        while month < 1 {
793                            month += 12;
794                            year -= 1;
795                        }
796
797                        // Create new date, handling day overflow (e.g., Jan 31 + 1 month = Feb 28/29)
798                        let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
799                            .unwrap_or_else(|| {
800                                // If day doesn't exist in target month, use the last day of that month
801                                // Try decreasing days until we find a valid one
802                                for test_day in (1..=day).rev() {
803                                    if let Some(date) =
804                                        NaiveDate::from_ymd_opt(year, month as u32, test_day)
805                                    {
806                                        return date;
807                                    }
808                                }
809                                // This should never happen, but fallback to day 28 as safety
810                                NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
811                            });
812
813                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
814                    }
815                    "year" | "years" => {
816                        let new_year = base_date.year() + amount as i32;
817                        let target_date =
818                            NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
819                                .unwrap_or_else(|| {
820                                    // Handle Feb 29 in non-leap years
821                                    NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
822                                        .unwrap()
823                                });
824                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
825                    }
826                    "hour" | "hours" => base_date + chrono::Duration::hours(amount),
827                    "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
828                    "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
829                    _ => {
830                        return Err(anyhow!(
831                            "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
832                            unit
833                        ))
834                    }
835                };
836
837                // Return as datetime string
838                Ok(DataValue::DateTime(
839                    result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
840                ))
841            }
842            "TEXTJOIN" => {
843                if args.len() < 3 {
844                    return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
845                }
846
847                // First argument: delimiter
848                let delimiter = match self.evaluate(&args[0], row_index)? {
849                    DataValue::String(s) => s,
850                    DataValue::InternedString(s) => s.to_string(),
851                    DataValue::Integer(n) => n.to_string(),
852                    DataValue::Float(f) => f.to_string(),
853                    DataValue::Boolean(b) => b.to_string(),
854                    DataValue::Null => String::new(),
855                    _ => String::new(),
856                };
857
858                // Second argument: ignore_empty (treat as boolean - 0 is false, anything else is true)
859                let ignore_empty = match self.evaluate(&args[1], row_index)? {
860                    DataValue::Integer(n) => n != 0,
861                    DataValue::Float(f) => f != 0.0,
862                    DataValue::Boolean(b) => b,
863                    DataValue::String(s) => {
864                        !s.is_empty() && s != "0" && s.to_lowercase() != "false"
865                    }
866                    DataValue::InternedString(s) => {
867                        !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
868                    }
869                    DataValue::Null => false,
870                    _ => true,
871                };
872
873                // Remaining arguments: values to join
874                let mut values = Vec::new();
875                for i in 2..args.len() {
876                    let value = self.evaluate(&args[i], row_index)?;
877                    let string_value = match value {
878                        DataValue::String(s) => Some(s),
879                        DataValue::InternedString(s) => Some(s.to_string()),
880                        DataValue::Integer(n) => Some(n.to_string()),
881                        DataValue::Float(f) => Some(f.to_string()),
882                        DataValue::Boolean(b) => Some(b.to_string()),
883                        DataValue::DateTime(dt) => Some(dt),
884                        DataValue::Null => {
885                            if ignore_empty {
886                                None
887                            } else {
888                                Some(String::new())
889                            }
890                        }
891                        _ => {
892                            if ignore_empty {
893                                None
894                            } else {
895                                Some(String::new())
896                            }
897                        }
898                    };
899
900                    if let Some(s) = string_value {
901                        if !ignore_empty || !s.is_empty() {
902                            values.push(s);
903                        }
904                    }
905                }
906
907                Ok(DataValue::String(values.join(&delimiter)))
908            }
909            _ => Err(anyhow!("Unknown function: {}", name)),
910        }
911    }
912
913    /// Evaluate a method call on a column (e.g., column.Trim())
914    fn evaluate_method_call(
915        &self,
916        object: &str,
917        method: &str,
918        args: &[SqlExpression],
919        row_index: usize,
920    ) -> Result<DataValue> {
921        // Get column value
922        let col_index = self.table.get_column_index(object).ok_or_else(|| {
923            let suggestion = self.find_similar_column(object);
924            match suggestion {
925                Some(similar) => {
926                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
927                }
928                None => anyhow!("Column '{}' not found", object),
929            }
930        })?;
931
932        let cell_value = self.table.get_value(row_index, col_index).cloned();
933
934        self.evaluate_method_on_value(
935            &cell_value.unwrap_or(DataValue::Null),
936            method,
937            args,
938            row_index,
939        )
940    }
941
942    /// Evaluate a method on a value
943    fn evaluate_method_on_value(
944        &self,
945        value: &DataValue,
946        method: &str,
947        args: &[SqlExpression],
948        row_index: usize,
949    ) -> Result<DataValue> {
950        match method.to_lowercase().as_str() {
951            "trim" | "trimstart" | "trimend" => {
952                if !args.is_empty() {
953                    return Err(anyhow!("{} takes no arguments", method));
954                }
955
956                // Convert value to string and apply trim
957                let str_val = match value {
958                    DataValue::String(s) => s.clone(),
959                    DataValue::InternedString(s) => s.to_string(),
960                    DataValue::Integer(n) => n.to_string(),
961                    DataValue::Float(f) => f.to_string(),
962                    DataValue::Boolean(b) => b.to_string(),
963                    DataValue::DateTime(dt) => dt.clone(),
964                    DataValue::Null => return Ok(DataValue::Null),
965                };
966
967                let result = match method.to_lowercase().as_str() {
968                    "trim" => str_val.trim().to_string(),
969                    "trimstart" => str_val.trim_start().to_string(),
970                    "trimend" => str_val.trim_end().to_string(),
971                    _ => unreachable!(),
972                };
973
974                Ok(DataValue::String(result))
975            }
976            "length" => {
977                if !args.is_empty() {
978                    return Err(anyhow!("Length takes no arguments"));
979                }
980
981                // Get string length
982                let len = match value {
983                    DataValue::String(s) => s.len(),
984                    DataValue::InternedString(s) => s.len(),
985                    DataValue::Integer(n) => n.to_string().len(),
986                    DataValue::Float(f) => f.to_string().len(),
987                    DataValue::Boolean(b) => b.to_string().len(),
988                    DataValue::DateTime(dt) => dt.len(),
989                    DataValue::Null => return Ok(DataValue::Integer(0)),
990                };
991
992                Ok(DataValue::Integer(len as i64))
993            }
994            "indexof" => {
995                if args.len() != 1 {
996                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
997                }
998
999                // Get the search string from args
1000                let search_str = match self.evaluate(&args[0], row_index)? {
1001                    DataValue::String(s) => s,
1002                    DataValue::InternedString(s) => s.to_string(),
1003                    DataValue::Integer(n) => n.to_string(),
1004                    DataValue::Float(f) => f.to_string(),
1005                    _ => return Err(anyhow!("IndexOf argument must be a string")),
1006                };
1007
1008                // Convert value to string and find index
1009                let str_val = match value {
1010                    DataValue::String(s) => s.clone(),
1011                    DataValue::InternedString(s) => s.to_string(),
1012                    DataValue::Integer(n) => n.to_string(),
1013                    DataValue::Float(f) => f.to_string(),
1014                    DataValue::Boolean(b) => b.to_string(),
1015                    DataValue::DateTime(dt) => dt.clone(),
1016                    DataValue::Null => return Ok(DataValue::Integer(-1)),
1017                };
1018
1019                let index = str_val.find(&search_str).map(|i| i as i64).unwrap_or(-1);
1020
1021                Ok(DataValue::Integer(index))
1022            }
1023            "contains" => {
1024                if args.len() != 1 {
1025                    return Err(anyhow!("Contains requires exactly 1 argument"));
1026                }
1027
1028                // Get the search string from args
1029                let search_str = match self.evaluate(&args[0], row_index)? {
1030                    DataValue::String(s) => s,
1031                    DataValue::InternedString(s) => s.to_string(),
1032                    DataValue::Integer(n) => n.to_string(),
1033                    DataValue::Float(f) => f.to_string(),
1034                    _ => return Err(anyhow!("Contains argument must be a string")),
1035                };
1036
1037                // Convert value to string and check contains
1038                let str_val = match value {
1039                    DataValue::String(s) => s.clone(),
1040                    DataValue::InternedString(s) => s.to_string(),
1041                    DataValue::Integer(n) => n.to_string(),
1042                    DataValue::Float(f) => f.to_string(),
1043                    DataValue::Boolean(b) => b.to_string(),
1044                    DataValue::DateTime(dt) => dt.clone(),
1045                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1046                };
1047
1048                // Case-insensitive search
1049                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1050                Ok(DataValue::Boolean(result))
1051            }
1052            "startswith" => {
1053                if args.len() != 1 {
1054                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
1055                }
1056
1057                // Get the prefix from args
1058                let prefix = match self.evaluate(&args[0], row_index)? {
1059                    DataValue::String(s) => s,
1060                    DataValue::InternedString(s) => s.to_string(),
1061                    DataValue::Integer(n) => n.to_string(),
1062                    DataValue::Float(f) => f.to_string(),
1063                    _ => return Err(anyhow!("StartsWith argument must be a string")),
1064                };
1065
1066                // Convert value to string and check starts_with
1067                let str_val = match value {
1068                    DataValue::String(s) => s.clone(),
1069                    DataValue::InternedString(s) => s.to_string(),
1070                    DataValue::Integer(n) => n.to_string(),
1071                    DataValue::Float(f) => f.to_string(),
1072                    DataValue::Boolean(b) => b.to_string(),
1073                    DataValue::DateTime(dt) => dt.clone(),
1074                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1075                };
1076
1077                // Case-insensitive check
1078                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1079                Ok(DataValue::Boolean(result))
1080            }
1081            "endswith" => {
1082                if args.len() != 1 {
1083                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
1084                }
1085
1086                // Get the suffix from args
1087                let suffix = match self.evaluate(&args[0], row_index)? {
1088                    DataValue::String(s) => s,
1089                    DataValue::InternedString(s) => s.to_string(),
1090                    DataValue::Integer(n) => n.to_string(),
1091                    DataValue::Float(f) => f.to_string(),
1092                    _ => return Err(anyhow!("EndsWith argument must be a string")),
1093                };
1094
1095                // Convert value to string and check ends_with
1096                let str_val = match value {
1097                    DataValue::String(s) => s.clone(),
1098                    DataValue::InternedString(s) => s.to_string(),
1099                    DataValue::Integer(n) => n.to_string(),
1100                    DataValue::Float(f) => f.to_string(),
1101                    DataValue::Boolean(b) => b.to_string(),
1102                    DataValue::DateTime(dt) => dt.clone(),
1103                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1104                };
1105
1106                // Case-insensitive check
1107                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1108                Ok(DataValue::Boolean(result))
1109            }
1110            _ => Err(anyhow!("Unsupported method: {}", method)),
1111        }
1112    }
1113}
1114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table with columns `a` = 10, `b` = 2.5, `c` = 4.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// A bare column reference yields the cell value of that row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    /// Number literals become `Integer` or `Float` depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 rather than 3.14: the latter trips Clippy's `approx_constant`
        // lint (it looks like a truncated PI) and 2.75 is exact in binary.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    /// Addition keeps integers integral and promotes mixed operands to float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    /// Integer * Float produces a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    /// Integer division stays integral when exact and becomes float otherwise.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by zero is reported as an error, not a panic or infinity.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    /// A binary expression over two columns is evaluated against the row.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    /// Argument-free methods: trimming, length, Null handling, unknown names.
    #[test]
    fn test_methods_without_arguments() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);
        let padded = DataValue::String("  padded  ".to_string());

        // Method names are matched case-insensitively.
        let result = evaluator
            .evaluate_method_on_value(&padded, "Trim", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::String("padded".to_string()));

        let result = evaluator
            .evaluate_method_on_value(&padded, "TRIMSTART", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::String("padded  ".to_string()));

        // Length works on the text form of non-string values.
        let result = evaluator
            .evaluate_method_on_value(&DataValue::Integer(12345), "Length", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Null receivers: Trim stays Null, Length is 0.
        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "Trim", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "Length", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(0));

        // Unknown methods are rejected with a descriptive error.
        let result = evaluator.evaluate_method_on_value(&padded, "Reverse", &[], 0);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported method"));
    }
}