sql_cli/data/
arithmetic_evaluator.rs

1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
/// This is different from `RecursiveWhereEvaluator` which returns boolean
///
/// The evaluator borrows its table for `'a` and is driven one row at a time
/// through `evaluate`.
pub struct ArithmeticEvaluator<'a> {
    /// Table whose columns and rows expressions are evaluated against.
    table: &'a DataTable,
    /// Date notation setting; `new` defaults it to `"us"`.
    /// NOTE(review): not read in the visible part of this file — presumably
    /// consumed by date handling elsewhere; confirm.
    date_notation: String,
    /// Registry looked up for scalar function calls and for method calls that
    /// are proxied to functions.
    function_registry: Arc<FunctionRegistry>,
    /// Registry looked up (by upper-cased name) for aggregate functions.
    aggregate_registry: Arc<AggregateRegistry>,
    /// When `Some`, COUNT(*) and aggregates only consider these row indices;
    /// `None` means every row of `table`.
    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
    /// Window contexts already built, keyed by the `Debug` rendering of their
    /// `WindowSpec`.
    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
}
22
23impl<'a> ArithmeticEvaluator<'a> {
24    #[must_use]
25    pub fn new(table: &'a DataTable) -> Self {
26        Self {
27            table,
28            date_notation: "us".to_string(),
29            function_registry: Arc::new(FunctionRegistry::new()),
30            aggregate_registry: Arc::new(AggregateRegistry::new()),
31            visible_rows: None,
32            window_contexts: HashMap::new(),
33        }
34    }
35
36    #[must_use]
37    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38        Self {
39            table,
40            date_notation,
41            function_registry: Arc::new(FunctionRegistry::new()),
42            aggregate_registry: Arc::new(AggregateRegistry::new()),
43            visible_rows: None,
44            window_contexts: HashMap::new(),
45        }
46    }
47
48    /// Set visible rows for aggregate functions (for filtered views)
49    #[must_use]
50    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51        self.visible_rows = Some(rows);
52        self
53    }
54
55    #[must_use]
56    pub fn with_date_notation_and_registry(
57        table: &'a DataTable,
58        date_notation: String,
59        function_registry: Arc<FunctionRegistry>,
60    ) -> Self {
61        Self {
62            table,
63            date_notation,
64            function_registry,
65            aggregate_registry: Arc::new(AggregateRegistry::new()),
66            visible_rows: None,
67            window_contexts: HashMap::new(),
68        }
69    }
70
71    /// Find a column name similar to the given name using edit distance
72    fn find_similar_column(&self, name: &str) -> Option<String> {
73        let columns = self.table.column_names();
74        let mut best_match: Option<(String, usize)> = None;
75
76        for col in columns {
77            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78            // Only suggest if distance is small (likely a typo)
79            // Allow up to 3 edits for longer names
80            let max_distance = if name.len() > 10 { 3 } else { 2 };
81            if distance <= max_distance {
82                match &best_match {
83                    None => best_match = Some((col, distance)),
84                    Some((_, best_dist)) if distance < *best_dist => {
85                        best_match = Some((col, distance));
86                    }
87                    _ => {}
88                }
89            }
90        }
91
92        best_match.map(|(name, _)| name)
93    }
94
95    /// Calculate Levenshtein edit distance between two strings
96    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97        // Use the shared implementation from string_methods
98        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99    }
100
101    /// Evaluate an SQL expression to produce a `DataValue`
102    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103        debug!(
104            "ArithmeticEvaluator: evaluating {:?} for row {}",
105            expr, row_index
106        );
107
108        match expr {
109            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113            SqlExpression::Null => Ok(DataValue::Null),
114            SqlExpression::BinaryOp { left, op, right } => {
115                self.evaluate_binary_op(left, op, right, row_index)
116            }
117            SqlExpression::FunctionCall { name, args } => {
118                self.evaluate_function(name, args, row_index)
119            }
120            SqlExpression::WindowFunction {
121                name,
122                args,
123                window_spec,
124            } => self.evaluate_window_function(name, args, window_spec, row_index),
125            SqlExpression::MethodCall {
126                object,
127                method,
128                args,
129            } => self.evaluate_method_call(object, method, args, row_index),
130            SqlExpression::ChainedMethodCall { base, method, args } => {
131                // Evaluate the base expression first, then apply the method
132                let base_value = self.evaluate(base, row_index)?;
133                self.evaluate_method_on_value(&base_value, method, args, row_index)
134            }
135            SqlExpression::CaseExpression {
136                when_branches,
137                else_branch,
138            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
139            _ => Err(anyhow!(
140                "Unsupported expression type for arithmetic evaluation: {:?}",
141                expr
142            )),
143        }
144    }
145
146    /// Evaluate a column reference
147    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
148        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
149            let suggestion = self.find_similar_column(column_name);
150            match suggestion {
151                Some(similar) => anyhow!(
152                    "Column '{}' not found. Did you mean '{}'?",
153                    column_name,
154                    similar
155                ),
156                None => anyhow!("Column '{}' not found", column_name),
157            }
158        })?;
159
160        if row_index >= self.table.row_count() {
161            return Err(anyhow!("Row index {} out of bounds", row_index));
162        }
163
164        let row = self
165            .table
166            .get_row(row_index)
167            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
168
169        let value = row
170            .get(col_index)
171            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
172
173        Ok(value.clone())
174    }
175
176    /// Evaluate a number literal (handles both integers and floats)
177    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
178        // Try to parse as integer first
179        if let Ok(int_val) = number_str.parse::<i64>() {
180            return Ok(DataValue::Integer(int_val));
181        }
182
183        // If that fails, try as float
184        if let Ok(float_val) = number_str.parse::<f64>() {
185            return Ok(DataValue::Float(float_val));
186        }
187
188        Err(anyhow!("Invalid number literal: {}", number_str))
189    }
190
191    /// Evaluate a binary operation (arithmetic)
192    fn evaluate_binary_op(
193        &mut self,
194        left: &SqlExpression,
195        op: &str,
196        right: &SqlExpression,
197        row_index: usize,
198    ) -> Result<DataValue> {
199        let left_val = self.evaluate(left, row_index)?;
200        let right_val = self.evaluate(right, row_index)?;
201
202        debug!(
203            "ArithmeticEvaluator: {} {} {}",
204            self.format_value(&left_val),
205            op,
206            self.format_value(&right_val)
207        );
208
209        match op {
210            "+" => self.add_values(&left_val, &right_val),
211            "-" => self.subtract_values(&left_val, &right_val),
212            "*" => self.multiply_values(&left_val, &right_val),
213            "/" => self.divide_values(&left_val, &right_val),
214            // Comparison operators (return boolean results)
215            ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
216            "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
217            ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
218            "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
219            "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
220            "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
221            // IS NULL / IS NOT NULL operators
222            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
223            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
224            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
225        }
226    }
227
228    /// Add two `DataValues` with type coercion
229    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
230        // NULL handling - any operation with NULL returns NULL
231        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
232            return Ok(DataValue::Null);
233        }
234
235        match (left, right) {
236            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
237            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
238            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
239            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
240            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
241        }
242    }
243
244    /// Subtract two `DataValues` with type coercion
245    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
246        // NULL handling - any operation with NULL returns NULL
247        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
248            return Ok(DataValue::Null);
249        }
250
251        match (left, right) {
252            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
253            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
254            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
255            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
256            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
257        }
258    }
259
260    /// Multiply two `DataValues` with type coercion
261    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
262        // NULL handling - any operation with NULL returns NULL
263        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
264            return Ok(DataValue::Null);
265        }
266
267        match (left, right) {
268            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
269            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
270            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
271            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
272            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
273        }
274    }
275
276    /// Divide two `DataValues` with type coercion
277    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
278        // NULL handling - any operation with NULL returns NULL
279        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
280            return Ok(DataValue::Null);
281        }
282
283        // Check for division by zero first
284        let is_zero = match right {
285            DataValue::Integer(0) => true,
286            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
287            _ => false,
288        };
289
290        if is_zero {
291            return Err(anyhow!("Division by zero"));
292        }
293
294        match (left, right) {
295            (DataValue::Integer(a), DataValue::Integer(b)) => {
296                // Integer division - if result is exact, keep as int, otherwise promote to float
297                if a % b == 0 {
298                    Ok(DataValue::Integer(a / b))
299                } else {
300                    Ok(DataValue::Float(*a as f64 / *b as f64))
301                }
302            }
303            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
304            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
305            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
306            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
307        }
308    }
309
310    /// Format a `DataValue` for debug output
311    fn format_value(&self, value: &DataValue) -> String {
312        match value {
313            DataValue::Integer(i) => i.to_string(),
314            DataValue::Float(f) => f.to_string(),
315            DataValue::String(s) => format!("'{s}'"),
316            _ => format!("{value:?}"),
317        }
318    }
319
320    /// Compare two `DataValues` using the provided comparison function
321    fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
322    where
323        F: Fn(f64, f64) -> bool,
324    {
325        debug!(
326            "ArithmeticEvaluator: comparing values {:?} and {:?}",
327            left, right
328        );
329
330        let result = match (left, right) {
331            // Integer comparisons
332            (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
333            (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
334            (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
335            (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
336
337            // String comparisons (lexicographic)
338            (DataValue::String(a), DataValue::String(b)) => {
339                let a_num = a.parse::<f64>();
340                let b_num = b.parse::<f64>();
341                match (a_num, b_num) {
342                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
343                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
344                }
345            }
346            (DataValue::InternedString(a), DataValue::InternedString(b)) => {
347                let a_num = a.parse::<f64>();
348                let b_num = b.parse::<f64>();
349                match (a_num, b_num) {
350                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
351                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
352                }
353            }
354            (DataValue::String(a), DataValue::InternedString(b)) => {
355                let a_num = a.parse::<f64>();
356                let b_num = b.parse::<f64>();
357                match (a_num, b_num) {
358                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
359                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
360                }
361            }
362            (DataValue::InternedString(a), DataValue::String(b)) => {
363                let a_num = a.parse::<f64>();
364                let b_num = b.parse::<f64>();
365                match (a_num, b_num) {
366                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
367                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
368                }
369            }
370
371            // Mixed type comparisons (try to convert to numbers)
372            (DataValue::String(a), DataValue::Integer(b)) => {
373                match a.parse::<f64>() {
374                    Ok(a_val) => op(a_val, *b as f64),
375                    Err(_) => false, // String can't be compared with number
376                }
377            }
378            (DataValue::Integer(a), DataValue::String(b)) => {
379                match b.parse::<f64>() {
380                    Ok(b_val) => op(*a as f64, b_val),
381                    Err(_) => false, // String can't be compared with number
382                }
383            }
384            (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
385                Ok(a_val) => op(a_val, *b),
386                Err(_) => false,
387            },
388            (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
389                Ok(b_val) => op(*a, b_val),
390                Err(_) => false,
391            },
392
393            // NULL comparisons
394            (DataValue::Null, _) | (_, DataValue::Null) => false,
395
396            // Boolean comparisons
397            (DataValue::Boolean(a), DataValue::Boolean(b)) => {
398                op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
399            }
400
401            _ => {
402                debug!(
403                    "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
404                    left, right
405                );
406                false
407            }
408        };
409
410        debug!("ArithmeticEvaluator: comparison result: {}", result);
411        Ok(DataValue::Boolean(result))
412    }
413
414    /// Evaluate a function call
415    fn evaluate_function(
416        &mut self,
417        name: &str,
418        args: &[SqlExpression],
419        row_index: usize,
420    ) -> Result<DataValue> {
421        // Check if this is an aggregate function
422        let name_upper = name.to_uppercase();
423
424        // Handle COUNT(*) special case
425        if name_upper == "COUNT" && args.len() == 1 {
426            match &args[0] {
427                SqlExpression::Column(col) if col == "*" => {
428                    // COUNT(*) - count all rows (or visible rows if filtered)
429                    let count = if let Some(ref visible) = self.visible_rows {
430                        visible.len() as i64
431                    } else {
432                        self.table.rows.len() as i64
433                    };
434                    return Ok(DataValue::Integer(count));
435                }
436                SqlExpression::StringLiteral(s) if s == "*" => {
437                    // COUNT(*) parsed as StringLiteral
438                    let count = if let Some(ref visible) = self.visible_rows {
439                        visible.len() as i64
440                    } else {
441                        self.table.rows.len() as i64
442                    };
443                    return Ok(DataValue::Integer(count));
444                }
445                _ => {
446                    // COUNT(column) - will be handled below
447                }
448            }
449        }
450
451        // Check aggregate registry
452        if self.aggregate_registry.get(&name_upper).is_some() {
453            // Determine which rows to process first
454            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
455                visible.clone()
456            } else {
457                (0..self.table.rows.len()).collect()
458            };
459
460            // Evaluate arguments first if needed (to avoid borrow issues)
461            let values = if !args.is_empty()
462                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
463            {
464                // Evaluate the argument expression for each row
465                let mut vals = Vec::new();
466                for &row_idx in &rows_to_process {
467                    let value = self.evaluate(&args[0], row_idx)?;
468                    vals.push(value);
469                }
470                Some(vals)
471            } else {
472                None
473            };
474
475            // Now get the aggregate function and process
476            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
477            let mut state = agg_func.init();
478
479            if let Some(values) = values {
480                // Use evaluated values
481                for value in &values {
482                    agg_func.accumulate(&mut state, value)?;
483                }
484            } else {
485                // COUNT(*) case
486                for _ in &rows_to_process {
487                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
488                }
489            }
490
491            return Ok(agg_func.finalize(state));
492        }
493
494        // First check if this function exists in the registry
495        if self.function_registry.get(name).is_some() {
496            // Evaluate all arguments first to avoid borrow issues
497            let mut evaluated_args = Vec::new();
498            for arg in args {
499                evaluated_args.push(self.evaluate(arg, row_index)?);
500            }
501
502            // Get the function and call it
503            let func = self.function_registry.get(name).unwrap();
504            return func.evaluate(&evaluated_args);
505        }
506
507        // If not in registry, return error for unknown function
508        Err(anyhow!("Unknown function: {}", name))
509    }
510
511    /// Get or create a WindowContext for the given specification
512    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
513        // Create a key for caching based on the spec
514        let key = format!("{:?}", spec);
515
516        if let Some(context) = self.window_contexts.get(&key) {
517            return Ok(Arc::clone(context));
518        }
519
520        // Create a DataView from the table (with visible rows if filtered)
521        let data_view = if let Some(ref visible_rows) = self.visible_rows {
522            // Create a filtered view
523            let mut view = DataView::new(Arc::new(self.table.clone()));
524            // Apply filtering based on visible rows
525            // Note: This is a simplified approach - in production we'd need proper filtering
526            view
527        } else {
528            DataView::new(Arc::new(self.table.clone()))
529        };
530
531        // Create the WindowContext
532        let context = WindowContext::new(
533            Arc::new(data_view),
534            spec.partition_by.clone(),
535            spec.order_by.clone(),
536        )?;
537
538        let context = Arc::new(context);
539        self.window_contexts.insert(key, Arc::clone(&context));
540        Ok(context)
541    }
542
543    /// Evaluate a window function
544    fn evaluate_window_function(
545        &mut self,
546        name: &str,
547        args: &[SqlExpression],
548        spec: &WindowSpec,
549        row_index: usize,
550    ) -> Result<DataValue> {
551        let context = self.get_or_create_window_context(spec)?;
552        let name_upper = name.to_uppercase();
553
554        match name_upper.as_str() {
555            "LAG" => {
556                // LAG(column, offset, default)
557                if args.is_empty() {
558                    return Err(anyhow!("LAG requires at least 1 argument"));
559                }
560
561                // Get column name
562                let column = match &args[0] {
563                    SqlExpression::Column(col) => col.clone(),
564                    _ => return Err(anyhow!("LAG first argument must be a column")),
565                };
566
567                // Get offset (default 1)
568                let offset = if args.len() > 1 {
569                    match self.evaluate(&args[1], row_index)? {
570                        DataValue::Integer(i) => i as i32,
571                        _ => return Err(anyhow!("LAG offset must be an integer")),
572                    }
573                } else {
574                    1
575                };
576
577                // Get value at offset
578                Ok(context
579                    .get_offset_value(row_index, -offset, &column)
580                    .unwrap_or(DataValue::Null))
581            }
582            "LEAD" => {
583                // LEAD(column, offset, default)
584                if args.is_empty() {
585                    return Err(anyhow!("LEAD requires at least 1 argument"));
586                }
587
588                // Get column name
589                let column = match &args[0] {
590                    SqlExpression::Column(col) => col.clone(),
591                    _ => return Err(anyhow!("LEAD first argument must be a column")),
592                };
593
594                // Get offset (default 1)
595                let offset = if args.len() > 1 {
596                    match self.evaluate(&args[1], row_index)? {
597                        DataValue::Integer(i) => i as i32,
598                        _ => return Err(anyhow!("LEAD offset must be an integer")),
599                    }
600                } else {
601                    1
602                };
603
604                // Get value at offset
605                Ok(context
606                    .get_offset_value(row_index, offset, &column)
607                    .unwrap_or(DataValue::Null))
608            }
609            "ROW_NUMBER" => {
610                // ROW_NUMBER() - no arguments
611                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
612            }
613            "FIRST_VALUE" => {
614                // FIRST_VALUE(column)
615                if args.is_empty() {
616                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
617                }
618
619                let column = match &args[0] {
620                    SqlExpression::Column(col) => col.clone(),
621                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
622                };
623
624                Ok(context
625                    .get_first_value(row_index, &column)
626                    .unwrap_or(DataValue::Null))
627            }
628            "LAST_VALUE" => {
629                // LAST_VALUE(column)
630                if args.is_empty() {
631                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
632                }
633
634                let column = match &args[0] {
635                    SqlExpression::Column(col) => col.clone(),
636                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
637                };
638
639                Ok(context
640                    .get_last_value(row_index, &column)
641                    .unwrap_or(DataValue::Null))
642            }
643            _ => Err(anyhow!("Unknown window function: {}", name)),
644        }
645    }
646
647    /// Evaluate a method call on a column (e.g., `column.Trim()`)
648    fn evaluate_method_call(
649        &mut self,
650        object: &str,
651        method: &str,
652        args: &[SqlExpression],
653        row_index: usize,
654    ) -> Result<DataValue> {
655        // Get column value
656        let col_index = self.table.get_column_index(object).ok_or_else(|| {
657            let suggestion = self.find_similar_column(object);
658            match suggestion {
659                Some(similar) => {
660                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
661                }
662                None => anyhow!("Column '{}' not found", object),
663            }
664        })?;
665
666        let cell_value = self.table.get_value(row_index, col_index).cloned();
667
668        self.evaluate_method_on_value(
669            &cell_value.unwrap_or(DataValue::Null),
670            method,
671            args,
672            row_index,
673        )
674    }
675
676    /// Evaluate a method on a value
677    fn evaluate_method_on_value(
678        &mut self,
679        value: &DataValue,
680        method: &str,
681        args: &[SqlExpression],
682        row_index: usize,
683    ) -> Result<DataValue> {
684        // First, try to proxy the method through the function registry
685        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
686
687        // Map method names to function names (case-insensitive matching)
688        let function_name = match method.to_lowercase().as_str() {
689            "trim" => "TRIM",
690            "trimstart" | "trimbegin" => "TRIMSTART",
691            "trimend" => "TRIMEND",
692            "length" | "len" => "LENGTH",
693            "contains" => "CONTAINS",
694            "startswith" => "STARTSWITH",
695            "endswith" => "ENDSWITH",
696            "indexof" => "INDEXOF",
697            _ => method, // Try the method name as-is
698        };
699
700        // Check if we have this function in the registry
701        if self.function_registry.get(function_name).is_some() {
702            debug!(
703                "Proxying method '{}' through function registry as '{}'",
704                method, function_name
705            );
706
707            // Prepare arguments: receiver is the first argument, followed by method args
708            let mut func_args = vec![value.clone()];
709
710            // Evaluate method arguments and add them
711            for arg in args {
712                func_args.push(self.evaluate(arg, row_index)?);
713            }
714
715            // Get the function and call it
716            let func = self.function_registry.get(function_name).unwrap();
717            return func.evaluate(&func_args);
718        }
719
720        // If not in registry, fall back to the old implementation for compatibility
721        match method.to_lowercase().as_str() {
722            "trim" | "trimstart" | "trimend" => {
723                if !args.is_empty() {
724                    return Err(anyhow!("{} takes no arguments", method));
725                }
726
727                // Convert value to string and apply trim
728                let str_val = match value {
729                    DataValue::String(s) => s.clone(),
730                    DataValue::InternedString(s) => s.to_string(),
731                    DataValue::Integer(n) => n.to_string(),
732                    DataValue::Float(f) => f.to_string(),
733                    DataValue::Boolean(b) => b.to_string(),
734                    DataValue::DateTime(dt) => dt.clone(),
735                    DataValue::Null => return Ok(DataValue::Null),
736                };
737
738                let result = match method.to_lowercase().as_str() {
739                    "trim" => str_val.trim().to_string(),
740                    "trimstart" => str_val.trim_start().to_string(),
741                    "trimend" => str_val.trim_end().to_string(),
742                    _ => unreachable!(),
743                };
744
745                Ok(DataValue::String(result))
746            }
747            "length" => {
748                if !args.is_empty() {
749                    return Err(anyhow!("Length takes no arguments"));
750                }
751
752                // Get string length
753                let len = match value {
754                    DataValue::String(s) => s.len(),
755                    DataValue::InternedString(s) => s.len(),
756                    DataValue::Integer(n) => n.to_string().len(),
757                    DataValue::Float(f) => f.to_string().len(),
758                    DataValue::Boolean(b) => b.to_string().len(),
759                    DataValue::DateTime(dt) => dt.len(),
760                    DataValue::Null => return Ok(DataValue::Integer(0)),
761                };
762
763                Ok(DataValue::Integer(len as i64))
764            }
765            "indexof" => {
766                if args.len() != 1 {
767                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
768                }
769
770                // Get the search string from args
771                let search_str = match self.evaluate(&args[0], row_index)? {
772                    DataValue::String(s) => s,
773                    DataValue::InternedString(s) => s.to_string(),
774                    DataValue::Integer(n) => n.to_string(),
775                    DataValue::Float(f) => f.to_string(),
776                    _ => return Err(anyhow!("IndexOf argument must be a string")),
777                };
778
779                // Convert value to string and find index
780                let str_val = match value {
781                    DataValue::String(s) => s.clone(),
782                    DataValue::InternedString(s) => s.to_string(),
783                    DataValue::Integer(n) => n.to_string(),
784                    DataValue::Float(f) => f.to_string(),
785                    DataValue::Boolean(b) => b.to_string(),
786                    DataValue::DateTime(dt) => dt.clone(),
787                    DataValue::Null => return Ok(DataValue::Integer(-1)),
788                };
789
790                let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
791
792                Ok(DataValue::Integer(index))
793            }
794            "contains" => {
795                if args.len() != 1 {
796                    return Err(anyhow!("Contains requires exactly 1 argument"));
797                }
798
799                // Get the search string from args
800                let search_str = match self.evaluate(&args[0], row_index)? {
801                    DataValue::String(s) => s,
802                    DataValue::InternedString(s) => s.to_string(),
803                    DataValue::Integer(n) => n.to_string(),
804                    DataValue::Float(f) => f.to_string(),
805                    _ => return Err(anyhow!("Contains argument must be a string")),
806                };
807
808                // Convert value to string and check contains
809                let str_val = match value {
810                    DataValue::String(s) => s.clone(),
811                    DataValue::InternedString(s) => s.to_string(),
812                    DataValue::Integer(n) => n.to_string(),
813                    DataValue::Float(f) => f.to_string(),
814                    DataValue::Boolean(b) => b.to_string(),
815                    DataValue::DateTime(dt) => dt.clone(),
816                    DataValue::Null => return Ok(DataValue::Boolean(false)),
817                };
818
819                // Case-insensitive search
820                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
821                Ok(DataValue::Boolean(result))
822            }
823            "startswith" => {
824                if args.len() != 1 {
825                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
826                }
827
828                // Get the prefix from args
829                let prefix = match self.evaluate(&args[0], row_index)? {
830                    DataValue::String(s) => s,
831                    DataValue::InternedString(s) => s.to_string(),
832                    DataValue::Integer(n) => n.to_string(),
833                    DataValue::Float(f) => f.to_string(),
834                    _ => return Err(anyhow!("StartsWith argument must be a string")),
835                };
836
837                // Convert value to string and check starts_with
838                let str_val = match value {
839                    DataValue::String(s) => s.clone(),
840                    DataValue::InternedString(s) => s.to_string(),
841                    DataValue::Integer(n) => n.to_string(),
842                    DataValue::Float(f) => f.to_string(),
843                    DataValue::Boolean(b) => b.to_string(),
844                    DataValue::DateTime(dt) => dt.clone(),
845                    DataValue::Null => return Ok(DataValue::Boolean(false)),
846                };
847
848                // Case-insensitive check
849                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
850                Ok(DataValue::Boolean(result))
851            }
852            "endswith" => {
853                if args.len() != 1 {
854                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
855                }
856
857                // Get the suffix from args
858                let suffix = match self.evaluate(&args[0], row_index)? {
859                    DataValue::String(s) => s,
860                    DataValue::InternedString(s) => s.to_string(),
861                    DataValue::Integer(n) => n.to_string(),
862                    DataValue::Float(f) => f.to_string(),
863                    _ => return Err(anyhow!("EndsWith argument must be a string")),
864                };
865
866                // Convert value to string and check ends_with
867                let str_val = match value {
868                    DataValue::String(s) => s.clone(),
869                    DataValue::InternedString(s) => s.to_string(),
870                    DataValue::Integer(n) => n.to_string(),
871                    DataValue::Float(f) => f.to_string(),
872                    DataValue::Boolean(b) => b.to_string(),
873                    DataValue::DateTime(dt) => dt.clone(),
874                    DataValue::Null => return Ok(DataValue::Boolean(false)),
875                };
876
877                // Case-insensitive check
878                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
879                Ok(DataValue::Boolean(result))
880            }
881            _ => Err(anyhow!("Unsupported method: {}", method)),
882        }
883    }
884
885    /// Evaluate a CASE expression
886    fn evaluate_case_expression(
887        &mut self,
888        when_branches: &[crate::sql::recursive_parser::WhenBranch],
889        else_branch: &Option<Box<SqlExpression>>,
890        row_index: usize,
891    ) -> Result<DataValue> {
892        debug!(
893            "ArithmeticEvaluator: evaluating CASE expression for row {}",
894            row_index
895        );
896
897        // Evaluate each WHEN condition in order
898        for branch in when_branches {
899            // Evaluate the condition as a boolean
900            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
901
902            if condition_result {
903                debug!("CASE: WHEN condition matched, evaluating result expression");
904                return self.evaluate(&branch.result, row_index);
905            }
906        }
907
908        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
909        if let Some(else_expr) = else_branch {
910            debug!("CASE: No WHEN matched, evaluating ELSE expression");
911            self.evaluate(else_expr, row_index)
912        } else {
913            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
914            Ok(DataValue::Null)
915        }
916    }
917
918    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
919    fn evaluate_condition_as_bool(
920        &mut self,
921        expr: &SqlExpression,
922        row_index: usize,
923    ) -> Result<bool> {
924        let value = self.evaluate(expr, row_index)?;
925
926        match value {
927            DataValue::Boolean(b) => Ok(b),
928            DataValue::Integer(i) => Ok(i != 0),
929            DataValue::Float(f) => Ok(f != 0.0),
930            DataValue::Null => Ok(false),
931            DataValue::String(s) => Ok(!s.is_empty()),
932            DataValue::InternedString(s) => Ok(!s.is_empty()),
933            _ => Ok(true), // Other types are considered truthy
934        }
935    }
936}
937
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds the shared fixture: a table named "test" with columns
    /// `a`, `b`, `c` and a single row `(10, 2.5, 4)`.
    fn create_test_table() -> DataTable {
        let mut fixture = DataTable::new("test");
        for column_name in ["a", "b", "c"] {
            fixture.add_column(DataColumn::new(column_name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        fixture.add_row(only_row).unwrap();

        fixture
    }

    /// A bare column reference yields the cell value of that row.
    #[test]
    fn test_evaluate_column() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);

        let column_a = SqlExpression::Column(String::from("a"));
        let cell = arith.evaluate(&column_a, 0).unwrap();
        assert_eq!(cell, DataValue::Integer(10));
    }

    /// Numeric literals become Integer or Float depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);

        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal_text, expected) in cases {
            let literal = SqlExpression::NumberLiteral(String::from(literal_text));
            let parsed = arith.evaluate(&literal, 0).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// Addition stays integral for two integers and widens to float otherwise.
    #[test]
    fn test_add_values() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);
        let five = DataValue::Integer(5);

        // Integer + Integer
        let int_sum = arith.add_values(&five, &DataValue::Integer(3)).unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        // Integer + Float
        let mixed_sum = arith.add_values(&five, &DataValue::Float(2.5)).unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Integer * Float produces a Float product.
    #[test]
    fn test_multiply_values() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);

        let lhs = DataValue::Integer(4);
        let rhs = DataValue::Float(2.5);
        let product = arith.multiply_values(&lhs, &rhs).unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Integer division stays integral when exact and becomes Float otherwise.
    #[test]
    fn test_divide_values() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);
        let ten = DataValue::Integer(10);

        // Exact division
        let exact = arith.divide_values(&ten, &DataValue::Integer(2)).unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        // Non-exact division
        let inexact = arith.divide_values(&ten, &DataValue::Integer(3)).unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic or NaN.
    #[test]
    fn test_division_by_zero() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);

        let outcome = arith.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// A binary-op expression tree is evaluated against the row's columns.
    #[test]
    fn test_binary_op_expression() {
        let fixture = create_test_table();
        let mut arith = ArithmeticEvaluator::new(&fixture);

        // a * b where a=10, b=2.5
        let lhs = Box::new(SqlExpression::Column(String::from("a")));
        let rhs = Box::new(SqlExpression::Column(String::from("b")));
        let product_expr = SqlExpression::BinaryOp {
            left: lhs,
            op: String::from("*"),
            right: rhs,
        };

        let product = arith.evaluate(&product_expr, 0).unwrap();
        assert_eq!(product, DataValue::Float(25.0));
    }
}