sql_cli/data/arithmetic_evaluator.rs

1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
/// Evaluates SQL expressions to compute `DataValue`s for a single row (used for SELECT
/// clauses).
///
/// This is different from `RecursiveWhereEvaluator`, which returns booleans. Comparison
/// operators are still accepted here; they produce `DataValue::Boolean` results.
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows are read when an expression references a column.
    table: &'a DataTable,
}
12
13impl<'a> ArithmeticEvaluator<'a> {
14    pub fn new(table: &'a DataTable) -> Self {
15        Self { table }
16    }
17
18    /// Find a column name similar to the given name using edit distance
19    fn find_similar_column(&self, name: &str) -> Option<String> {
20        let columns = self.table.column_names();
21        let mut best_match: Option<(String, usize)> = None;
22
23        for col in columns {
24            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25            // Only suggest if distance is small (likely a typo)
26            // Allow up to 3 edits for longer names
27            let max_distance = if name.len() > 10 { 3 } else { 2 };
28            if distance <= max_distance {
29                match &best_match {
30                    None => best_match = Some((col, distance)),
31                    Some((_, best_dist)) if distance < *best_dist => {
32                        best_match = Some((col, distance));
33                    }
34                    _ => {}
35                }
36            }
37        }
38
39        best_match.map(|(name, _)| name)
40    }
41
42    /// Calculate Levenshtein edit distance between two strings
43    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44        let len1 = s1.len();
45        let len2 = s2.len();
46        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48        for i in 0..=len1 {
49            matrix[i][0] = i;
50        }
51        for j in 0..=len2 {
52            matrix[0][j] = j;
53        }
54
55        for (i, c1) in s1.chars().enumerate() {
56            for (j, c2) in s2.chars().enumerate() {
57                let cost = if c1 == c2 { 0 } else { 1 };
58                matrix[i + 1][j + 1] = std::cmp::min(
59                    matrix[i][j + 1] + 1, // deletion
60                    std::cmp::min(
61                        matrix[i + 1][j] + 1, // insertion
62                        matrix[i][j] + cost,  // substitution
63                    ),
64                );
65            }
66        }
67
68        matrix[len1][len2]
69    }
70
71    /// Evaluate an SQL expression to produce a DataValue
72    pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73        debug!(
74            "ArithmeticEvaluator: evaluating {:?} for row {}",
75            expr, row_index
76        );
77
78        match expr {
79            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82            SqlExpression::BinaryOp { left, op, right } => {
83                self.evaluate_binary_op(left, op, right, row_index)
84            }
85            SqlExpression::FunctionCall { name, args } => {
86                self.evaluate_function(name, args, row_index)
87            }
88            SqlExpression::MethodCall {
89                object,
90                method,
91                args,
92            } => self.evaluate_method_call(object, method, args, row_index),
93            SqlExpression::ChainedMethodCall { base, method, args } => {
94                // Evaluate the base expression first, then apply the method
95                let base_value = self.evaluate(base, row_index)?;
96                self.evaluate_method_on_value(&base_value, method, args, row_index)
97            }
98            SqlExpression::CaseExpression {
99                when_branches,
100                else_branch,
101            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
102            _ => Err(anyhow!(
103                "Unsupported expression type for arithmetic evaluation: {:?}",
104                expr
105            )),
106        }
107    }
108
109    /// Evaluate a column reference
110    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
111        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
112            let suggestion = self.find_similar_column(column_name);
113            match suggestion {
114                Some(similar) => anyhow!(
115                    "Column '{}' not found. Did you mean '{}'?",
116                    column_name,
117                    similar
118                ),
119                None => anyhow!("Column '{}' not found", column_name),
120            }
121        })?;
122
123        if row_index >= self.table.row_count() {
124            return Err(anyhow!("Row index {} out of bounds", row_index));
125        }
126
127        let row = self
128            .table
129            .get_row(row_index)
130            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
131
132        let value = row
133            .get(col_index)
134            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
135
136        Ok(value.clone())
137    }
138
139    /// Evaluate a number literal (handles both integers and floats)
140    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
141        // Try to parse as integer first
142        if let Ok(int_val) = number_str.parse::<i64>() {
143            return Ok(DataValue::Integer(int_val));
144        }
145
146        // If that fails, try as float
147        if let Ok(float_val) = number_str.parse::<f64>() {
148            return Ok(DataValue::Float(float_val));
149        }
150
151        Err(anyhow!("Invalid number literal: {}", number_str))
152    }
153
154    /// Evaluate a binary operation (arithmetic)
155    fn evaluate_binary_op(
156        &self,
157        left: &SqlExpression,
158        op: &str,
159        right: &SqlExpression,
160        row_index: usize,
161    ) -> Result<DataValue> {
162        let left_val = self.evaluate(left, row_index)?;
163        let right_val = self.evaluate(right, row_index)?;
164
165        debug!(
166            "ArithmeticEvaluator: {} {} {}",
167            self.format_value(&left_val),
168            op,
169            self.format_value(&right_val)
170        );
171
172        match op {
173            "+" => self.add_values(&left_val, &right_val),
174            "-" => self.subtract_values(&left_val, &right_val),
175            "*" => self.multiply_values(&left_val, &right_val),
176            "/" => self.divide_values(&left_val, &right_val),
177            // Comparison operators (return boolean results)
178            ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
179            "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
180            ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
181            "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
182            "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
183            "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
184            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
185        }
186    }
187
188    /// Add two DataValues with type coercion
189    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
190        match (left, right) {
191            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
192            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
193            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
194            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
195            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
196        }
197    }
198
199    /// Subtract two DataValues with type coercion
200    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
201        match (left, right) {
202            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
203            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
204            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
205            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
206            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
207        }
208    }
209
210    /// Multiply two DataValues with type coercion
211    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
212        match (left, right) {
213            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
214            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
215            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
216            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
217            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
218        }
219    }
220
221    /// Divide two DataValues with type coercion
222    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
223        // Check for division by zero first
224        let is_zero = match right {
225            DataValue::Integer(0) => true,
226            DataValue::Float(f) if f.abs() < f64::EPSILON => true,
227            _ => false,
228        };
229
230        if is_zero {
231            return Err(anyhow!("Division by zero"));
232        }
233
234        match (left, right) {
235            (DataValue::Integer(a), DataValue::Integer(b)) => {
236                // Integer division - if result is exact, keep as int, otherwise promote to float
237                if a % b == 0 {
238                    Ok(DataValue::Integer(a / b))
239                } else {
240                    Ok(DataValue::Float(*a as f64 / *b as f64))
241                }
242            }
243            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
244            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
245            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
246            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
247        }
248    }
249
250    /// Format a DataValue for debug output
251    fn format_value(&self, value: &DataValue) -> String {
252        match value {
253            DataValue::Integer(i) => i.to_string(),
254            DataValue::Float(f) => f.to_string(),
255            DataValue::String(s) => format!("'{}'", s),
256            _ => format!("{:?}", value),
257        }
258    }
259
260    /// Compare two DataValues using the provided comparison function
261    fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
262    where
263        F: Fn(f64, f64) -> bool,
264    {
265        debug!(
266            "ArithmeticEvaluator: comparing values {:?} and {:?}",
267            left, right
268        );
269
270        let result = match (left, right) {
271            // Integer comparisons
272            (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
273            (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
274            (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
275            (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
276
277            // String comparisons (lexicographic)
278            (DataValue::String(a), DataValue::String(b)) => {
279                let a_num = a.parse::<f64>();
280                let b_num = b.parse::<f64>();
281                match (a_num, b_num) {
282                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
283                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
284                }
285            }
286            (DataValue::InternedString(a), DataValue::InternedString(b)) => {
287                let a_num = a.parse::<f64>();
288                let b_num = b.parse::<f64>();
289                match (a_num, b_num) {
290                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
291                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
292                }
293            }
294            (DataValue::String(a), DataValue::InternedString(b)) => {
295                let a_num = a.parse::<f64>();
296                let b_num = b.parse::<f64>();
297                match (a_num, b_num) {
298                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
299                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
300                }
301            }
302            (DataValue::InternedString(a), DataValue::String(b)) => {
303                let a_num = a.parse::<f64>();
304                let b_num = b.parse::<f64>();
305                match (a_num, b_num) {
306                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
307                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
308                }
309            }
310
311            // Mixed type comparisons (try to convert to numbers)
312            (DataValue::String(a), DataValue::Integer(b)) => {
313                match a.parse::<f64>() {
314                    Ok(a_val) => op(a_val, *b as f64),
315                    Err(_) => false, // String can't be compared with number
316                }
317            }
318            (DataValue::Integer(a), DataValue::String(b)) => {
319                match b.parse::<f64>() {
320                    Ok(b_val) => op(*a as f64, b_val),
321                    Err(_) => false, // String can't be compared with number
322                }
323            }
324            (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
325                Ok(a_val) => op(a_val, *b),
326                Err(_) => false,
327            },
328            (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
329                Ok(b_val) => op(*a, b_val),
330                Err(_) => false,
331            },
332
333            // NULL comparisons
334            (DataValue::Null, _) | (_, DataValue::Null) => false,
335
336            // Boolean comparisons
337            (DataValue::Boolean(a), DataValue::Boolean(b)) => {
338                op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
339            }
340
341            _ => {
342                debug!(
343                    "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
344                    left, right
345                );
346                false
347            }
348        };
349
350        debug!("ArithmeticEvaluator: comparison result: {}", result);
351        Ok(DataValue::Boolean(result))
352    }
353
354    /// Evaluate a function call
355    fn evaluate_function(
356        &self,
357        name: &str,
358        args: &[SqlExpression],
359        row_index: usize,
360    ) -> Result<DataValue> {
361        match name {
362            "ROUND" => {
363                if args.is_empty() || args.len() > 2 {
364                    return Err(anyhow!("ROUND requires 1 or 2 arguments"));
365                }
366
367                // Evaluate the value to round
368                let value = self.evaluate(&args[0], row_index)?;
369
370                // Get decimal places (default to 0 if not specified)
371                let decimals = if args.len() == 2 {
372                    match self.evaluate(&args[1], row_index)? {
373                        DataValue::Integer(n) => n as i32,
374                        DataValue::Float(f) => f as i32,
375                        _ => return Err(anyhow!("ROUND precision must be a number")),
376                    }
377                } else {
378                    0
379                };
380
381                // Perform rounding
382                match value {
383                    DataValue::Integer(n) => Ok(DataValue::Integer(n)), // Already an integer
384                    DataValue::Float(f) => {
385                        if decimals >= 0 {
386                            let multiplier = 10_f64.powi(decimals);
387                            let rounded = (f * multiplier).round() / multiplier;
388                            if decimals == 0 {
389                                // Return as integer if rounding to 0 decimals
390                                Ok(DataValue::Integer(rounded as i64))
391                            } else {
392                                Ok(DataValue::Float(rounded))
393                            }
394                        } else {
395                            // Negative decimals round to left of decimal point
396                            let divisor = 10_f64.powi(-decimals);
397                            let rounded = (f / divisor).round() * divisor;
398                            Ok(DataValue::Float(rounded))
399                        }
400                    }
401                    _ => Err(anyhow!("ROUND can only be applied to numeric values")),
402                }
403            }
404            "ABS" => {
405                if args.len() != 1 {
406                    return Err(anyhow!("ABS requires exactly 1 argument"));
407                }
408
409                let value = self.evaluate(&args[0], row_index)?;
410                match value {
411                    DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
412                    DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
413                    _ => Err(anyhow!("ABS can only be applied to numeric values")),
414                }
415            }
416            "FLOOR" => {
417                if args.len() != 1 {
418                    return Err(anyhow!("FLOOR requires exactly 1 argument"));
419                }
420
421                let value = self.evaluate(&args[0], row_index)?;
422                match value {
423                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
424                    DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
425                    _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
426                }
427            }
428            "CEILING" | "CEIL" => {
429                if args.len() != 1 {
430                    return Err(anyhow!("CEILING requires exactly 1 argument"));
431                }
432
433                let value = self.evaluate(&args[0], row_index)?;
434                match value {
435                    DataValue::Integer(n) => Ok(DataValue::Integer(n)),
436                    DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
437                    _ => Err(anyhow!("CEILING can only be applied to numeric values")),
438                }
439            }
440            "MOD" => {
441                if args.len() != 2 {
442                    return Err(anyhow!("MOD requires exactly 2 arguments"));
443                }
444
445                let dividend = self.evaluate(&args[0], row_index)?;
446                let divisor = self.evaluate(&args[1], row_index)?;
447
448                match (&dividend, &divisor) {
449                    (DataValue::Integer(n), DataValue::Integer(d)) => {
450                        if *d == 0 {
451                            return Err(anyhow!("Division by zero in MOD"));
452                        }
453                        Ok(DataValue::Integer(n % d))
454                    }
455                    _ => {
456                        // Convert to float for mixed types
457                        let n = match dividend {
458                            DataValue::Integer(i) => i as f64,
459                            DataValue::Float(f) => f,
460                            _ => return Err(anyhow!("MOD requires numeric arguments")),
461                        };
462                        let d = match divisor {
463                            DataValue::Integer(i) => i as f64,
464                            DataValue::Float(f) => f,
465                            _ => return Err(anyhow!("MOD requires numeric arguments")),
466                        };
467                        if d == 0.0 {
468                            return Err(anyhow!("Division by zero in MOD"));
469                        }
470                        Ok(DataValue::Float(n % d))
471                    }
472                }
473            }
474            "QUOTIENT" => {
475                if args.len() != 2 {
476                    return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
477                }
478
479                let numerator = self.evaluate(&args[0], row_index)?;
480                let denominator = self.evaluate(&args[1], row_index)?;
481
482                match (&numerator, &denominator) {
483                    (DataValue::Integer(n), DataValue::Integer(d)) => {
484                        if *d == 0 {
485                            return Err(anyhow!("Division by zero in QUOTIENT"));
486                        }
487                        Ok(DataValue::Integer(n / d))
488                    }
489                    _ => {
490                        // Convert to float for mixed types
491                        let n = match numerator {
492                            DataValue::Integer(i) => i as f64,
493                            DataValue::Float(f) => f,
494                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
495                        };
496                        let d = match denominator {
497                            DataValue::Integer(i) => i as f64,
498                            DataValue::Float(f) => f,
499                            _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
500                        };
501                        if d == 0.0 {
502                            return Err(anyhow!("Division by zero in QUOTIENT"));
503                        }
504                        Ok(DataValue::Integer((n / d).trunc() as i64))
505                    }
506                }
507            }
508            "POWER" | "POW" => {
509                if args.len() != 2 {
510                    return Err(anyhow!("POWER requires exactly 2 arguments"));
511                }
512
513                let base = self.evaluate(&args[0], row_index)?;
514                let exponent = self.evaluate(&args[1], row_index)?;
515
516                match (&base, &exponent) {
517                    (DataValue::Integer(b), DataValue::Integer(e)) => {
518                        if *e >= 0 && *e <= i32::MAX as i64 {
519                            Ok(DataValue::Float((*b as f64).powi(*e as i32)))
520                        } else {
521                            Ok(DataValue::Float((*b as f64).powf(*e as f64)))
522                        }
523                    }
524                    _ => {
525                        // Convert to float for mixed types or floats
526                        let b = match base {
527                            DataValue::Integer(i) => i as f64,
528                            DataValue::Float(f) => f,
529                            _ => return Err(anyhow!("POWER requires numeric arguments")),
530                        };
531                        let e = match exponent {
532                            DataValue::Integer(i) => i as f64,
533                            DataValue::Float(f) => f,
534                            _ => return Err(anyhow!("POWER requires numeric arguments")),
535                        };
536                        Ok(DataValue::Float(b.powf(e)))
537                    }
538                }
539            }
540            "SQRT" => {
541                if args.len() != 1 {
542                    return Err(anyhow!("SQRT requires exactly 1 argument"));
543                }
544
545                let value = self.evaluate(&args[0], row_index)?;
546                match value {
547                    DataValue::Integer(n) => {
548                        if n < 0 {
549                            return Err(anyhow!("SQRT of negative number"));
550                        }
551                        Ok(DataValue::Float((n as f64).sqrt()))
552                    }
553                    DataValue::Float(f) => {
554                        if f < 0.0 {
555                            return Err(anyhow!("SQRT of negative number"));
556                        }
557                        Ok(DataValue::Float(f.sqrt()))
558                    }
559                    _ => Err(anyhow!("SQRT can only be applied to numeric values")),
560                }
561            }
562            "EXP" => {
563                if args.len() != 1 {
564                    return Err(anyhow!("EXP requires exactly 1 argument"));
565                }
566
567                let value = self.evaluate(&args[0], row_index)?;
568                match value {
569                    DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
570                    DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
571                    _ => Err(anyhow!("EXP can only be applied to numeric values")),
572                }
573            }
574            "LN" => {
575                if args.len() != 1 {
576                    return Err(anyhow!("LN requires exactly 1 argument"));
577                }
578
579                let value = self.evaluate(&args[0], row_index)?;
580                match value {
581                    DataValue::Integer(n) => {
582                        if n <= 0 {
583                            return Err(anyhow!("LN of non-positive number"));
584                        }
585                        Ok(DataValue::Float((n as f64).ln()))
586                    }
587                    DataValue::Float(f) => {
588                        if f <= 0.0 {
589                            return Err(anyhow!("LN of non-positive number"));
590                        }
591                        Ok(DataValue::Float(f.ln()))
592                    }
593                    _ => Err(anyhow!("LN can only be applied to numeric values")),
594                }
595            }
596            "LOG" | "LOG10" => {
597                if name == "LOG" && args.len() == 2 {
598                    // LOG with custom base
599                    let value = self.evaluate(&args[0], row_index)?;
600                    let base = self.evaluate(&args[1], row_index)?;
601
602                    let n = match value {
603                        DataValue::Integer(i) => i as f64,
604                        DataValue::Float(f) => f,
605                        _ => return Err(anyhow!("LOG requires numeric arguments")),
606                    };
607                    let b = match base {
608                        DataValue::Integer(i) => i as f64,
609                        DataValue::Float(f) => f,
610                        _ => return Err(anyhow!("LOG requires numeric arguments")),
611                    };
612
613                    if n <= 0.0 {
614                        return Err(anyhow!("LOG of non-positive number"));
615                    }
616                    if b <= 0.0 || b == 1.0 {
617                        return Err(anyhow!("Invalid LOG base"));
618                    }
619                    Ok(DataValue::Float(n.log(b)))
620                } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
621                    // LOG10 or LOG with default base 10
622                    if args.len() != 1 {
623                        return Err(anyhow!("{} requires exactly 1 argument", name));
624                    }
625
626                    let value = self.evaluate(&args[0], row_index)?;
627                    match value {
628                        DataValue::Integer(n) => {
629                            if n <= 0 {
630                                return Err(anyhow!("LOG10 of non-positive number"));
631                            }
632                            Ok(DataValue::Float((n as f64).log10()))
633                        }
634                        DataValue::Float(f) => {
635                            if f <= 0.0 {
636                                return Err(anyhow!("LOG10 of non-positive number"));
637                            }
638                            Ok(DataValue::Float(f.log10()))
639                        }
640                        _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
641                    }
642                } else {
643                    Err(anyhow!("LOG requires 1 or 2 arguments"))
644                }
645            }
646            "PI" => {
647                if !args.is_empty() {
648                    return Err(anyhow!("PI takes no arguments"));
649                }
650                Ok(DataValue::Float(std::f64::consts::PI))
651            }
652            "DATEDIFF" => {
653                if args.len() != 3 {
654                    return Err(anyhow!(
655                        "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
656                    ));
657                }
658
659                // First argument: unit (day, month, year, hour, minute, second)
660                let unit = match self.evaluate(&args[0], row_index)? {
661                    DataValue::String(s) => s.to_lowercase(),
662                    DataValue::InternedString(s) => s.to_lowercase(),
663                    _ => return Err(anyhow!("DATEDIFF unit must be a string")),
664                };
665
666                // Helper function to parse date/datetime strings
667                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
668                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
669                        // Try various date/datetime formats
670
671                        // ISO formats (most common)
672                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
673                            return Ok(Utc.from_utc_datetime(&dt));
674                        }
675                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
676                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
677                        }
678
679                        // US format: MM/DD/YYYY or MM-DD-YYYY
680                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
681                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
682                        }
683                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
684                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
685                        }
686
687                        // European format: DD/MM/YYYY or DD-MM-YYYY
688                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
689                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
690                        }
691                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
692                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
693                        }
694
695                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
696                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
697                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
698                        }
699
700                        // Full month names: January 15, 2024 or 15 January 2024
701                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
702                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
703                        }
704                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
705                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
706                        }
707
708                        // With time: MM/DD/YYYY HH:MM:SS
709                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
710                            return Ok(Utc.from_utc_datetime(&dt));
711                        }
712                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
713                            return Ok(Utc.from_utc_datetime(&dt));
714                        }
715
716                        // ISO 8601 / RFC3339
717                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
718                            return Ok(dt);
719                        }
720
721                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
722                    };
723
724                    match value {
725                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
726                        DataValue::InternedString(s) => parse_string(s.as_str()),
727                        _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
728                    }
729                };
730
731                // Parse both dates
732                let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
733                let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
734
735                // Calculate difference based on unit
736                let diff = match unit.as_str() {
737                    "day" | "days" => {
738                        let duration = date2.signed_duration_since(date1);
739                        duration.num_days()
740                    }
741                    "month" | "months" => {
742                        // Approximate months as 30.44 days
743                        let duration = date2.signed_duration_since(date1);
744                        duration.num_days() / 30
745                    }
746                    "year" | "years" => {
747                        // Approximate years as 365.25 days
748                        let duration = date2.signed_duration_since(date1);
749                        duration.num_days() / 365
750                    }
751                    "hour" | "hours" => {
752                        let duration = date2.signed_duration_since(date1);
753                        duration.num_hours()
754                    }
755                    "minute" | "minutes" => {
756                        let duration = date2.signed_duration_since(date1);
757                        duration.num_minutes()
758                    }
759                    "second" | "seconds" => {
760                        let duration = date2.signed_duration_since(date1);
761                        duration.num_seconds()
762                    }
763                    _ => {
764                        return Err(anyhow!(
765                        "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
766                        unit
767                    ))
768                    }
769                };
770
771                Ok(DataValue::Integer(diff))
772            }
773            "NOW" => {
774                if !args.is_empty() {
775                    return Err(anyhow!("NOW takes no arguments"));
776                }
777                let now = Utc::now();
778                Ok(DataValue::DateTime(
779                    now.format("%Y-%m-%d %H:%M:%S").to_string(),
780                ))
781            }
782            "TODAY" => {
783                if !args.is_empty() {
784                    return Err(anyhow!("TODAY takes no arguments"));
785                }
786                let today = Utc::now().date_naive();
787                Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
788            }
789            "DATEADD" => {
790                if args.len() != 3 {
791                    return Err(anyhow!(
792                        "DATEADD requires exactly 3 arguments: unit, number, date"
793                    ));
794                }
795
796                // First argument: unit (day, month, year, hour, minute, second)
797                let unit = match self.evaluate(&args[0], row_index)? {
798                    DataValue::String(s) => s.to_lowercase(),
799                    DataValue::InternedString(s) => s.to_lowercase(),
800                    _ => return Err(anyhow!("DATEADD unit must be a string")),
801                };
802
803                // Second argument: number to add (can be negative for subtraction)
804                let amount = match self.evaluate(&args[1], row_index)? {
805                    DataValue::Integer(i) => i,
806                    DataValue::Float(f) => f as i64,
807                    _ => return Err(anyhow!("DATEADD amount must be a number")),
808                };
809
810                // Third argument: base date
811                let base_date_value = self.evaluate(&args[2], row_index)?;
812
813                // Reuse the parse_datetime function from DATEDIFF
814                let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
815                    let parse_string = |s: &str| -> Result<DateTime<Utc>> {
816                        // Try various date/datetime formats (same as DATEDIFF)
817
818                        // ISO formats (most common)
819                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
820                            return Ok(Utc.from_utc_datetime(&dt));
821                        }
822                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
823                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
824                        }
825
826                        // US format: MM/DD/YYYY or MM-DD-YYYY
827                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
828                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
829                        }
830                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
831                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
832                        }
833
834                        // European format: DD/MM/YYYY or DD-MM-YYYY
835                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
836                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
837                        }
838                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
839                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
840                        }
841
842                        // Excel/Windows format: DD-MMM-YYYY (e.g., 15-Jan-2024)
843                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
844                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
845                        }
846
847                        // Full month names: January 15, 2024 or 15 January 2024
848                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
849                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
850                        }
851                        if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
852                            return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
853                        }
854
855                        // With time: MM/DD/YYYY HH:MM:SS
856                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
857                            return Ok(Utc.from_utc_datetime(&dt));
858                        }
859                        if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
860                            return Ok(Utc.from_utc_datetime(&dt));
861                        }
862
863                        // ISO 8601 / RFC3339
864                        if let Ok(dt) = s.parse::<DateTime<Utc>>() {
865                            return Ok(dt);
866                        }
867
868                        Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
869                    };
870
871                    match value {
872                        DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
873                        DataValue::InternedString(s) => parse_string(s.as_str()),
874                        _ => Err(anyhow!("DATEADD requires date/datetime values")),
875                    }
876                };
877
878                // Parse the base date
879                let base_date = parse_datetime(base_date_value)?;
880
881                // Add the specified amount based on unit
882                let result_date = match unit.as_str() {
883                    "day" | "days" => base_date + chrono::Duration::days(amount),
884                    "month" | "months" => {
885                        // For months, we need to be careful about month boundaries
886                        let mut year = base_date.year();
887                        let mut month = base_date.month() as i32;
888                        let day = base_date.day();
889
890                        month += amount as i32;
891
892                        // Handle month overflow/underflow
893                        while month > 12 {
894                            month -= 12;
895                            year += 1;
896                        }
897                        while month < 1 {
898                            month += 12;
899                            year -= 1;
900                        }
901
902                        // Create new date, handling day overflow (e.g., Jan 31 + 1 month = Feb 28/29)
903                        let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
904                            .unwrap_or_else(|| {
905                                // If day doesn't exist in target month, use the last day of that month
906                                // Try decreasing days until we find a valid one
907                                for test_day in (1..=day).rev() {
908                                    if let Some(date) =
909                                        NaiveDate::from_ymd_opt(year, month as u32, test_day)
910                                    {
911                                        return date;
912                                    }
913                                }
914                                // This should never happen, but fallback to day 28 as safety
915                                NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
916                            });
917
918                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
919                    }
920                    "year" | "years" => {
921                        let new_year = base_date.year() + amount as i32;
922                        let target_date =
923                            NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
924                                .unwrap_or_else(|| {
925                                    // Handle Feb 29 in non-leap years
926                                    NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
927                                        .unwrap()
928                                });
929                        Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
930                    }
931                    "hour" | "hours" => base_date + chrono::Duration::hours(amount),
932                    "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
933                    "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
934                    _ => {
935                        return Err(anyhow!(
936                            "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
937                            unit
938                        ))
939                    }
940                };
941
942                // Return as datetime string
943                Ok(DataValue::DateTime(
944                    result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
945                ))
946            }
947            "TEXTJOIN" => {
948                if args.len() < 3 {
949                    return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
950                }
951
952                // First argument: delimiter
953                let delimiter = match self.evaluate(&args[0], row_index)? {
954                    DataValue::String(s) => s,
955                    DataValue::InternedString(s) => s.to_string(),
956                    DataValue::Integer(n) => n.to_string(),
957                    DataValue::Float(f) => f.to_string(),
958                    DataValue::Boolean(b) => b.to_string(),
959                    DataValue::Null => String::new(),
960                    _ => String::new(),
961                };
962
963                // Second argument: ignore_empty (treat as boolean - 0 is false, anything else is true)
964                let ignore_empty = match self.evaluate(&args[1], row_index)? {
965                    DataValue::Integer(n) => n != 0,
966                    DataValue::Float(f) => f != 0.0,
967                    DataValue::Boolean(b) => b,
968                    DataValue::String(s) => {
969                        !s.is_empty() && s != "0" && s.to_lowercase() != "false"
970                    }
971                    DataValue::InternedString(s) => {
972                        !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
973                    }
974                    DataValue::Null => false,
975                    _ => true,
976                };
977
978                // Remaining arguments: values to join
979                let mut values = Vec::new();
980                for i in 2..args.len() {
981                    let value = self.evaluate(&args[i], row_index)?;
982                    let string_value = match value {
983                        DataValue::String(s) => Some(s),
984                        DataValue::InternedString(s) => Some(s.to_string()),
985                        DataValue::Integer(n) => Some(n.to_string()),
986                        DataValue::Float(f) => Some(f.to_string()),
987                        DataValue::Boolean(b) => Some(b.to_string()),
988                        DataValue::DateTime(dt) => Some(dt),
989                        DataValue::Null => {
990                            if ignore_empty {
991                                None
992                            } else {
993                                Some(String::new())
994                            }
995                        }
996                        _ => {
997                            if ignore_empty {
998                                None
999                            } else {
1000                                Some(String::new())
1001                            }
1002                        }
1003                    };
1004
1005                    if let Some(s) = string_value {
1006                        if !ignore_empty || !s.is_empty() {
1007                            values.push(s);
1008                        }
1009                    }
1010                }
1011
1012                Ok(DataValue::String(values.join(&delimiter)))
1013            }
1014            _ => Err(anyhow!("Unknown function: {}", name)),
1015        }
1016    }
1017
1018    /// Evaluate a method call on a column (e.g., column.Trim())
1019    fn evaluate_method_call(
1020        &self,
1021        object: &str,
1022        method: &str,
1023        args: &[SqlExpression],
1024        row_index: usize,
1025    ) -> Result<DataValue> {
1026        // Get column value
1027        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1028            let suggestion = self.find_similar_column(object);
1029            match suggestion {
1030                Some(similar) => {
1031                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1032                }
1033                None => anyhow!("Column '{}' not found", object),
1034            }
1035        })?;
1036
1037        let cell_value = self.table.get_value(row_index, col_index).cloned();
1038
1039        self.evaluate_method_on_value(
1040            &cell_value.unwrap_or(DataValue::Null),
1041            method,
1042            args,
1043            row_index,
1044        )
1045    }
1046
1047    /// Evaluate a method on a value
1048    fn evaluate_method_on_value(
1049        &self,
1050        value: &DataValue,
1051        method: &str,
1052        args: &[SqlExpression],
1053        row_index: usize,
1054    ) -> Result<DataValue> {
1055        match method.to_lowercase().as_str() {
1056            "trim" | "trimstart" | "trimend" => {
1057                if !args.is_empty() {
1058                    return Err(anyhow!("{} takes no arguments", method));
1059                }
1060
1061                // Convert value to string and apply trim
1062                let str_val = match value {
1063                    DataValue::String(s) => s.clone(),
1064                    DataValue::InternedString(s) => s.to_string(),
1065                    DataValue::Integer(n) => n.to_string(),
1066                    DataValue::Float(f) => f.to_string(),
1067                    DataValue::Boolean(b) => b.to_string(),
1068                    DataValue::DateTime(dt) => dt.clone(),
1069                    DataValue::Null => return Ok(DataValue::Null),
1070                };
1071
1072                let result = match method.to_lowercase().as_str() {
1073                    "trim" => str_val.trim().to_string(),
1074                    "trimstart" => str_val.trim_start().to_string(),
1075                    "trimend" => str_val.trim_end().to_string(),
1076                    _ => unreachable!(),
1077                };
1078
1079                Ok(DataValue::String(result))
1080            }
1081            "length" => {
1082                if !args.is_empty() {
1083                    return Err(anyhow!("Length takes no arguments"));
1084                }
1085
1086                // Get string length
1087                let len = match value {
1088                    DataValue::String(s) => s.len(),
1089                    DataValue::InternedString(s) => s.len(),
1090                    DataValue::Integer(n) => n.to_string().len(),
1091                    DataValue::Float(f) => f.to_string().len(),
1092                    DataValue::Boolean(b) => b.to_string().len(),
1093                    DataValue::DateTime(dt) => dt.len(),
1094                    DataValue::Null => return Ok(DataValue::Integer(0)),
1095                };
1096
1097                Ok(DataValue::Integer(len as i64))
1098            }
1099            "indexof" => {
1100                if args.len() != 1 {
1101                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
1102                }
1103
1104                // Get the search string from args
1105                let search_str = match self.evaluate(&args[0], row_index)? {
1106                    DataValue::String(s) => s,
1107                    DataValue::InternedString(s) => s.to_string(),
1108                    DataValue::Integer(n) => n.to_string(),
1109                    DataValue::Float(f) => f.to_string(),
1110                    _ => return Err(anyhow!("IndexOf argument must be a string")),
1111                };
1112
1113                // Convert value to string and find index
1114                let str_val = match value {
1115                    DataValue::String(s) => s.clone(),
1116                    DataValue::InternedString(s) => s.to_string(),
1117                    DataValue::Integer(n) => n.to_string(),
1118                    DataValue::Float(f) => f.to_string(),
1119                    DataValue::Boolean(b) => b.to_string(),
1120                    DataValue::DateTime(dt) => dt.clone(),
1121                    DataValue::Null => return Ok(DataValue::Integer(-1)),
1122                };
1123
1124                let index = str_val.find(&search_str).map(|i| i as i64).unwrap_or(-1);
1125
1126                Ok(DataValue::Integer(index))
1127            }
1128            "contains" => {
1129                if args.len() != 1 {
1130                    return Err(anyhow!("Contains requires exactly 1 argument"));
1131                }
1132
1133                // Get the search string from args
1134                let search_str = match self.evaluate(&args[0], row_index)? {
1135                    DataValue::String(s) => s,
1136                    DataValue::InternedString(s) => s.to_string(),
1137                    DataValue::Integer(n) => n.to_string(),
1138                    DataValue::Float(f) => f.to_string(),
1139                    _ => return Err(anyhow!("Contains argument must be a string")),
1140                };
1141
1142                // Convert value to string and check contains
1143                let str_val = match value {
1144                    DataValue::String(s) => s.clone(),
1145                    DataValue::InternedString(s) => s.to_string(),
1146                    DataValue::Integer(n) => n.to_string(),
1147                    DataValue::Float(f) => f.to_string(),
1148                    DataValue::Boolean(b) => b.to_string(),
1149                    DataValue::DateTime(dt) => dt.clone(),
1150                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1151                };
1152
1153                // Case-insensitive search
1154                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1155                Ok(DataValue::Boolean(result))
1156            }
1157            "startswith" => {
1158                if args.len() != 1 {
1159                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
1160                }
1161
1162                // Get the prefix from args
1163                let prefix = match self.evaluate(&args[0], row_index)? {
1164                    DataValue::String(s) => s,
1165                    DataValue::InternedString(s) => s.to_string(),
1166                    DataValue::Integer(n) => n.to_string(),
1167                    DataValue::Float(f) => f.to_string(),
1168                    _ => return Err(anyhow!("StartsWith argument must be a string")),
1169                };
1170
1171                // Convert value to string and check starts_with
1172                let str_val = match value {
1173                    DataValue::String(s) => s.clone(),
1174                    DataValue::InternedString(s) => s.to_string(),
1175                    DataValue::Integer(n) => n.to_string(),
1176                    DataValue::Float(f) => f.to_string(),
1177                    DataValue::Boolean(b) => b.to_string(),
1178                    DataValue::DateTime(dt) => dt.clone(),
1179                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1180                };
1181
1182                // Case-insensitive check
1183                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1184                Ok(DataValue::Boolean(result))
1185            }
1186            "endswith" => {
1187                if args.len() != 1 {
1188                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
1189                }
1190
1191                // Get the suffix from args
1192                let suffix = match self.evaluate(&args[0], row_index)? {
1193                    DataValue::String(s) => s,
1194                    DataValue::InternedString(s) => s.to_string(),
1195                    DataValue::Integer(n) => n.to_string(),
1196                    DataValue::Float(f) => f.to_string(),
1197                    _ => return Err(anyhow!("EndsWith argument must be a string")),
1198                };
1199
1200                // Convert value to string and check ends_with
1201                let str_val = match value {
1202                    DataValue::String(s) => s.clone(),
1203                    DataValue::InternedString(s) => s.to_string(),
1204                    DataValue::Integer(n) => n.to_string(),
1205                    DataValue::Float(f) => f.to_string(),
1206                    DataValue::Boolean(b) => b.to_string(),
1207                    DataValue::DateTime(dt) => dt.clone(),
1208                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1209                };
1210
1211                // Case-insensitive check
1212                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1213                Ok(DataValue::Boolean(result))
1214            }
1215            _ => Err(anyhow!("Unsupported method: {}", method)),
1216        }
1217    }
1218
1219    /// Evaluate a CASE expression
1220    fn evaluate_case_expression(
1221        &self,
1222        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1223        else_branch: &Option<Box<SqlExpression>>,
1224        row_index: usize,
1225    ) -> Result<DataValue> {
1226        debug!(
1227            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1228            row_index
1229        );
1230
1231        // Evaluate each WHEN condition in order
1232        for branch in when_branches {
1233            // Evaluate the condition as a boolean
1234            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1235
1236            if condition_result {
1237                debug!("CASE: WHEN condition matched, evaluating result expression");
1238                return self.evaluate(&branch.result, row_index);
1239            }
1240        }
1241
1242        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1243        match else_branch {
1244            Some(else_expr) => {
1245                debug!("CASE: No WHEN matched, evaluating ELSE expression");
1246                self.evaluate(else_expr, row_index)
1247            }
1248            None => {
1249                debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1250                Ok(DataValue::Null)
1251            }
1252        }
1253    }
1254
1255    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1256    fn evaluate_condition_as_bool(&self, expr: &SqlExpression, row_index: usize) -> Result<bool> {
1257        let value = self.evaluate(expr, row_index)?;
1258
1259        match value {
1260            DataValue::Boolean(b) => Ok(b),
1261            DataValue::Integer(i) => Ok(i != 0),
1262            DataValue::Float(f) => Ok(f != 0.0),
1263            DataValue::Null => Ok(false),
1264            DataValue::String(s) => Ok(!s.is_empty()),
1265            DataValue::InternedString(s) => Ok(!s.is_empty()),
1266            _ => Ok(true), // Other types are considered truthy
1267        }
1268    }
1269}
1270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// One-row fixture: a = Integer(10), b = Float(2.5), c = Integer(4).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Deliberately not 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects float literals that approximate PI.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_trim_methods() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);
        let padded = DataValue::String("  padded  ".to_string());

        // Method names are matched case-insensitively.
        let result = evaluator
            .evaluate_method_on_value(&padded, "TRIM", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::String("padded".to_string()));

        let result = evaluator
            .evaluate_method_on_value(&padded, "TrimStart", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::String("padded  ".to_string()));

        let result = evaluator
            .evaluate_method_on_value(&padded, "TrimEnd", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::String("  padded".to_string()));

        // NULL stays NULL, and Trim rejects arguments.
        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "Trim", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let args = [SqlExpression::Column("a".to_string())];
        assert!(evaluator
            .evaluate_method_on_value(&padded, "Trim", &args, 0)
            .is_err());
    }

    #[test]
    fn test_search_methods() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);
        let text = DataValue::String("Item10".to_string());
        // Column `a` evaluates to Integer(10), which is searched for as "10".
        let args = [SqlExpression::Column("a".to_string())];

        let result = evaluator
            .evaluate_method_on_value(&text, "IndexOf", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(4));

        let result = evaluator
            .evaluate_method_on_value(&text, "Contains", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(true));

        let result = evaluator
            .evaluate_method_on_value(&text, "StartsWith", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(false));

        let result = evaluator
            .evaluate_method_on_value(&text, "EndsWith", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(true));

        // NULL receivers produce fixed defaults.
        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "IndexOf", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(-1));

        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "Contains", &args, 0)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(false));

        let result = evaluator
            .evaluate_method_on_value(&DataValue::Null, "Length", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(0));
    }

    #[test]
    fn test_method_call_on_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Column `a` holds Integer(10), whose text form "10" has length 2.
        let result = evaluator
            .evaluate_method_call("a", "Length", &[], 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(2));

        let err = evaluator
            .evaluate_method_call("missing_column", "Length", &[], 0)
            .unwrap_err();
        assert!(err.to_string().contains("not found"));

        assert!(evaluator
            .evaluate_method_call("a", "NoSuchMethod", &[], 0)
            .is_err());
    }
}
1390}