sql_cli/data/
arithmetic_evaluator.rs

1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
13/// This is different from `RecursiveWhereEvaluator` which returns boolean
14pub struct ArithmeticEvaluator<'a> {
15    table: &'a DataTable,
16    date_notation: String,
17    function_registry: Arc<FunctionRegistry>,
18    aggregate_registry: Arc<AggregateRegistry>,
19    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
20    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
21}
22
23impl<'a> ArithmeticEvaluator<'a> {
24    #[must_use]
25    pub fn new(table: &'a DataTable) -> Self {
26        Self {
27            table,
28            date_notation: "us".to_string(),
29            function_registry: Arc::new(FunctionRegistry::new()),
30            aggregate_registry: Arc::new(AggregateRegistry::new()),
31            visible_rows: None,
32            window_contexts: HashMap::new(),
33        }
34    }
35
36    #[must_use]
37    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38        Self {
39            table,
40            date_notation,
41            function_registry: Arc::new(FunctionRegistry::new()),
42            aggregate_registry: Arc::new(AggregateRegistry::new()),
43            visible_rows: None,
44            window_contexts: HashMap::new(),
45        }
46    }
47
48    /// Set visible rows for aggregate functions (for filtered views)
49    #[must_use]
50    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51        self.visible_rows = Some(rows);
52        self
53    }
54
55    #[must_use]
56    pub fn with_date_notation_and_registry(
57        table: &'a DataTable,
58        date_notation: String,
59        function_registry: Arc<FunctionRegistry>,
60    ) -> Self {
61        Self {
62            table,
63            date_notation,
64            function_registry,
65            aggregate_registry: Arc::new(AggregateRegistry::new()),
66            visible_rows: None,
67            window_contexts: HashMap::new(),
68        }
69    }
70
71    /// Find a column name similar to the given name using edit distance
72    fn find_similar_column(&self, name: &str) -> Option<String> {
73        let columns = self.table.column_names();
74        let mut best_match: Option<(String, usize)> = None;
75
76        for col in columns {
77            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78            // Only suggest if distance is small (likely a typo)
79            // Allow up to 3 edits for longer names
80            let max_distance = if name.len() > 10 { 3 } else { 2 };
81            if distance <= max_distance {
82                match &best_match {
83                    None => best_match = Some((col, distance)),
84                    Some((_, best_dist)) if distance < *best_dist => {
85                        best_match = Some((col, distance));
86                    }
87                    _ => {}
88                }
89            }
90        }
91
92        best_match.map(|(name, _)| name)
93    }
94
95    /// Calculate Levenshtein edit distance between two strings
96    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97        // Use the shared implementation from string_methods
98        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99    }
100
101    /// Evaluate an SQL expression to produce a `DataValue`
102    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103        debug!(
104            "ArithmeticEvaluator: evaluating {:?} for row {}",
105            expr, row_index
106        );
107
108        match expr {
109            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113            SqlExpression::Null => Ok(DataValue::Null),
114            SqlExpression::BinaryOp { left, op, right } => {
115                self.evaluate_binary_op(left, op, right, row_index)
116            }
117            SqlExpression::FunctionCall { name, args } => {
118                self.evaluate_function(name, args, row_index)
119            }
120            SqlExpression::WindowFunction {
121                name,
122                args,
123                window_spec,
124            } => self.evaluate_window_function(name, args, window_spec, row_index),
125            SqlExpression::MethodCall {
126                object,
127                method,
128                args,
129            } => self.evaluate_method_call(object, method, args, row_index),
130            SqlExpression::ChainedMethodCall { base, method, args } => {
131                // Evaluate the base expression first, then apply the method
132                let base_value = self.evaluate(base, row_index)?;
133                self.evaluate_method_on_value(&base_value, method, args, row_index)
134            }
135            SqlExpression::CaseExpression {
136                when_branches,
137                else_branch,
138            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
139            _ => Err(anyhow!(
140                "Unsupported expression type for arithmetic evaluation: {:?}",
141                expr
142            )),
143        }
144    }
145
146    /// Evaluate a column reference
147    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
148        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
149            let suggestion = self.find_similar_column(column_name);
150            match suggestion {
151                Some(similar) => anyhow!(
152                    "Column '{}' not found. Did you mean '{}'?",
153                    column_name,
154                    similar
155                ),
156                None => anyhow!("Column '{}' not found", column_name),
157            }
158        })?;
159
160        if row_index >= self.table.row_count() {
161            return Err(anyhow!("Row index {} out of bounds", row_index));
162        }
163
164        let row = self
165            .table
166            .get_row(row_index)
167            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
168
169        let value = row
170            .get(col_index)
171            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
172
173        Ok(value.clone())
174    }
175
176    /// Evaluate a number literal (handles both integers and floats)
177    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
178        // Try to parse as integer first
179        if let Ok(int_val) = number_str.parse::<i64>() {
180            return Ok(DataValue::Integer(int_val));
181        }
182
183        // If that fails, try as float
184        if let Ok(float_val) = number_str.parse::<f64>() {
185            return Ok(DataValue::Float(float_val));
186        }
187
188        Err(anyhow!("Invalid number literal: {}", number_str))
189    }
190
191    /// Evaluate a binary operation (arithmetic)
192    fn evaluate_binary_op(
193        &mut self,
194        left: &SqlExpression,
195        op: &str,
196        right: &SqlExpression,
197        row_index: usize,
198    ) -> Result<DataValue> {
199        let left_val = self.evaluate(left, row_index)?;
200        let right_val = self.evaluate(right, row_index)?;
201
202        debug!(
203            "ArithmeticEvaluator: {} {} {}",
204            self.format_value(&left_val),
205            op,
206            self.format_value(&right_val)
207        );
208
209        match op {
210            "+" => self.add_values(&left_val, &right_val),
211            "-" => self.subtract_values(&left_val, &right_val),
212            "*" => self.multiply_values(&left_val, &right_val),
213            "/" => self.divide_values(&left_val, &right_val),
214            "%" => {
215                // Modulo operator - call MOD function
216                let args = vec![left.clone(), right.clone()];
217                self.evaluate_function("MOD", &args, row_index)
218            }
219            // Comparison operators (return boolean results)
220            ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
221            "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
222            ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
223            "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
224            "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
225            "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
226            // IS NULL / IS NOT NULL operators
227            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
228            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
229            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
230        }
231    }
232
233    /// Add two `DataValues` with type coercion
234    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
235        // NULL handling - any operation with NULL returns NULL
236        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
237            return Ok(DataValue::Null);
238        }
239
240        match (left, right) {
241            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
242            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
243            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
244            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
245            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
246        }
247    }
248
249    /// Subtract two `DataValues` with type coercion
250    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
251        // NULL handling - any operation with NULL returns NULL
252        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
253            return Ok(DataValue::Null);
254        }
255
256        match (left, right) {
257            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
258            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
259            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
260            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
261            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
262        }
263    }
264
265    /// Multiply two `DataValues` with type coercion
266    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
267        // NULL handling - any operation with NULL returns NULL
268        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
269            return Ok(DataValue::Null);
270        }
271
272        match (left, right) {
273            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
274            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
275            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
276            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
277            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
278        }
279    }
280
281    /// Divide two `DataValues` with type coercion
282    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
283        // NULL handling - any operation with NULL returns NULL
284        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
285            return Ok(DataValue::Null);
286        }
287
288        // Check for division by zero first
289        let is_zero = match right {
290            DataValue::Integer(0) => true,
291            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
292            _ => false,
293        };
294
295        if is_zero {
296            return Err(anyhow!("Division by zero"));
297        }
298
299        match (left, right) {
300            (DataValue::Integer(a), DataValue::Integer(b)) => {
301                // Integer division - if result is exact, keep as int, otherwise promote to float
302                if a % b == 0 {
303                    Ok(DataValue::Integer(a / b))
304                } else {
305                    Ok(DataValue::Float(*a as f64 / *b as f64))
306                }
307            }
308            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
309            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
310            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
311            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
312        }
313    }
314
315    /// Format a `DataValue` for debug output
316    fn format_value(&self, value: &DataValue) -> String {
317        match value {
318            DataValue::Integer(i) => i.to_string(),
319            DataValue::Float(f) => f.to_string(),
320            DataValue::String(s) => format!("'{s}'"),
321            _ => format!("{value:?}"),
322        }
323    }
324
325    /// Compare two `DataValues` using the provided comparison function
326    fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
327    where
328        F: Fn(f64, f64) -> bool,
329    {
330        debug!(
331            "ArithmeticEvaluator: comparing values {:?} and {:?}",
332            left, right
333        );
334
335        let result = match (left, right) {
336            // Integer comparisons
337            (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
338            (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
339            (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
340            (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
341
342            // String comparisons (lexicographic)
343            (DataValue::String(a), DataValue::String(b)) => {
344                let a_num = a.parse::<f64>();
345                let b_num = b.parse::<f64>();
346                match (a_num, b_num) {
347                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
348                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
349                }
350            }
351            (DataValue::InternedString(a), DataValue::InternedString(b)) => {
352                let a_num = a.parse::<f64>();
353                let b_num = b.parse::<f64>();
354                match (a_num, b_num) {
355                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
356                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
357                }
358            }
359            (DataValue::String(a), DataValue::InternedString(b)) => {
360                let a_num = a.parse::<f64>();
361                let b_num = b.parse::<f64>();
362                match (a_num, b_num) {
363                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
364                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
365                }
366            }
367            (DataValue::InternedString(a), DataValue::String(b)) => {
368                let a_num = a.parse::<f64>();
369                let b_num = b.parse::<f64>();
370                match (a_num, b_num) {
371                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
372                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
373                }
374            }
375
376            // Mixed type comparisons (try to convert to numbers)
377            (DataValue::String(a), DataValue::Integer(b)) => {
378                match a.parse::<f64>() {
379                    Ok(a_val) => op(a_val, *b as f64),
380                    Err(_) => false, // String can't be compared with number
381                }
382            }
383            (DataValue::Integer(a), DataValue::String(b)) => {
384                match b.parse::<f64>() {
385                    Ok(b_val) => op(*a as f64, b_val),
386                    Err(_) => false, // String can't be compared with number
387                }
388            }
389            (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
390                Ok(a_val) => op(a_val, *b),
391                Err(_) => false,
392            },
393            (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
394                Ok(b_val) => op(*a, b_val),
395                Err(_) => false,
396            },
397
398            // NULL comparisons
399            (DataValue::Null, _) | (_, DataValue::Null) => false,
400
401            // Boolean comparisons
402            (DataValue::Boolean(a), DataValue::Boolean(b)) => {
403                op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
404            }
405
406            _ => {
407                debug!(
408                    "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
409                    left, right
410                );
411                false
412            }
413        };
414
415        debug!("ArithmeticEvaluator: comparison result: {}", result);
416        Ok(DataValue::Boolean(result))
417    }
418
419    /// Evaluate a function call
420    fn evaluate_function(
421        &mut self,
422        name: &str,
423        args: &[SqlExpression],
424        row_index: usize,
425    ) -> Result<DataValue> {
426        // Check if this is an aggregate function
427        let name_upper = name.to_uppercase();
428
429        // Handle COUNT(*) special case
430        if name_upper == "COUNT" && args.len() == 1 {
431            match &args[0] {
432                SqlExpression::Column(col) if col == "*" => {
433                    // COUNT(*) - count all rows (or visible rows if filtered)
434                    let count = if let Some(ref visible) = self.visible_rows {
435                        visible.len() as i64
436                    } else {
437                        self.table.rows.len() as i64
438                    };
439                    return Ok(DataValue::Integer(count));
440                }
441                SqlExpression::StringLiteral(s) if s == "*" => {
442                    // COUNT(*) parsed as StringLiteral
443                    let count = if let Some(ref visible) = self.visible_rows {
444                        visible.len() as i64
445                    } else {
446                        self.table.rows.len() as i64
447                    };
448                    return Ok(DataValue::Integer(count));
449                }
450                _ => {
451                    // COUNT(column) - will be handled below
452                }
453            }
454        }
455
456        // Check aggregate registry
457        if self.aggregate_registry.get(&name_upper).is_some() {
458            // Determine which rows to process first
459            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
460                visible.clone()
461            } else {
462                (0..self.table.rows.len()).collect()
463            };
464
465            // Evaluate arguments first if needed (to avoid borrow issues)
466            let values = if !args.is_empty()
467                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
468            {
469                // Evaluate the argument expression for each row
470                let mut vals = Vec::new();
471                for &row_idx in &rows_to_process {
472                    let value = self.evaluate(&args[0], row_idx)?;
473                    vals.push(value);
474                }
475                Some(vals)
476            } else {
477                None
478            };
479
480            // Now get the aggregate function and process
481            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
482            let mut state = agg_func.init();
483
484            if let Some(values) = values {
485                // Use evaluated values
486                for value in &values {
487                    agg_func.accumulate(&mut state, value)?;
488                }
489            } else {
490                // COUNT(*) case
491                for _ in &rows_to_process {
492                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
493                }
494            }
495
496            return Ok(agg_func.finalize(state));
497        }
498
499        // First check if this function exists in the registry
500        if self.function_registry.get(name).is_some() {
501            // Evaluate all arguments first to avoid borrow issues
502            let mut evaluated_args = Vec::new();
503            for arg in args {
504                evaluated_args.push(self.evaluate(arg, row_index)?);
505            }
506
507            // Get the function and call it
508            let func = self.function_registry.get(name).unwrap();
509            return func.evaluate(&evaluated_args);
510        }
511
512        // If not in registry, return error for unknown function
513        Err(anyhow!("Unknown function: {}", name))
514    }
515
516    /// Get or create a WindowContext for the given specification
517    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
518        // Create a key for caching based on the spec
519        let key = format!("{:?}", spec);
520
521        if let Some(context) = self.window_contexts.get(&key) {
522            return Ok(Arc::clone(context));
523        }
524
525        // Create a DataView from the table (with visible rows if filtered)
526        let data_view = if let Some(ref visible_rows) = self.visible_rows {
527            // Create a filtered view
528            let mut view = DataView::new(Arc::new(self.table.clone()));
529            // Apply filtering based on visible rows
530            // Note: This is a simplified approach - in production we'd need proper filtering
531            view
532        } else {
533            DataView::new(Arc::new(self.table.clone()))
534        };
535
536        // Create the WindowContext
537        let context = WindowContext::new(
538            Arc::new(data_view),
539            spec.partition_by.clone(),
540            spec.order_by.clone(),
541        )?;
542
543        let context = Arc::new(context);
544        self.window_contexts.insert(key, Arc::clone(&context));
545        Ok(context)
546    }
547
548    /// Evaluate a window function
549    fn evaluate_window_function(
550        &mut self,
551        name: &str,
552        args: &[SqlExpression],
553        spec: &WindowSpec,
554        row_index: usize,
555    ) -> Result<DataValue> {
556        let context = self.get_or_create_window_context(spec)?;
557        let name_upper = name.to_uppercase();
558
559        match name_upper.as_str() {
560            "LAG" => {
561                // LAG(column, offset, default)
562                if args.is_empty() {
563                    return Err(anyhow!("LAG requires at least 1 argument"));
564                }
565
566                // Get column name
567                let column = match &args[0] {
568                    SqlExpression::Column(col) => col.clone(),
569                    _ => return Err(anyhow!("LAG first argument must be a column")),
570                };
571
572                // Get offset (default 1)
573                let offset = if args.len() > 1 {
574                    match self.evaluate(&args[1], row_index)? {
575                        DataValue::Integer(i) => i as i32,
576                        _ => return Err(anyhow!("LAG offset must be an integer")),
577                    }
578                } else {
579                    1
580                };
581
582                // Get value at offset
583                Ok(context
584                    .get_offset_value(row_index, -offset, &column)
585                    .unwrap_or(DataValue::Null))
586            }
587            "LEAD" => {
588                // LEAD(column, offset, default)
589                if args.is_empty() {
590                    return Err(anyhow!("LEAD requires at least 1 argument"));
591                }
592
593                // Get column name
594                let column = match &args[0] {
595                    SqlExpression::Column(col) => col.clone(),
596                    _ => return Err(anyhow!("LEAD first argument must be a column")),
597                };
598
599                // Get offset (default 1)
600                let offset = if args.len() > 1 {
601                    match self.evaluate(&args[1], row_index)? {
602                        DataValue::Integer(i) => i as i32,
603                        _ => return Err(anyhow!("LEAD offset must be an integer")),
604                    }
605                } else {
606                    1
607                };
608
609                // Get value at offset
610                Ok(context
611                    .get_offset_value(row_index, offset, &column)
612                    .unwrap_or(DataValue::Null))
613            }
614            "ROW_NUMBER" => {
615                // ROW_NUMBER() - no arguments
616                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
617            }
618            "FIRST_VALUE" => {
619                // FIRST_VALUE(column)
620                if args.is_empty() {
621                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
622                }
623
624                let column = match &args[0] {
625                    SqlExpression::Column(col) => col.clone(),
626                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
627                };
628
629                Ok(context
630                    .get_first_value(row_index, &column)
631                    .unwrap_or(DataValue::Null))
632            }
633            "LAST_VALUE" => {
634                // LAST_VALUE(column)
635                if args.is_empty() {
636                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
637                }
638
639                let column = match &args[0] {
640                    SqlExpression::Column(col) => col.clone(),
641                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
642                };
643
644                Ok(context
645                    .get_last_value(row_index, &column)
646                    .unwrap_or(DataValue::Null))
647            }
648            "SUM" => {
649                // SUM(column) OVER (PARTITION BY ...)
650                if args.is_empty() {
651                    return Err(anyhow!("SUM requires 1 argument"));
652                }
653
654                let column = match &args[0] {
655                    SqlExpression::Column(col) => col.clone(),
656                    _ => return Err(anyhow!("SUM argument must be a column")),
657                };
658
659                Ok(context
660                    .get_partition_sum(row_index, &column)
661                    .unwrap_or(DataValue::Null))
662            }
663            "COUNT" => {
664                // COUNT(*) or COUNT(column) OVER (PARTITION BY ...)
665                // Note: In window functions, COUNT(*) seems to come with no args
666                if args.is_empty() {
667                    // COUNT(*) OVER (...) - count all rows in partition
668                    Ok(context
669                        .get_partition_count(row_index, None)
670                        .unwrap_or(DataValue::Null))
671                } else {
672                    // Check for COUNT(*)
673                    let column = match &args[0] {
674                        SqlExpression::Column(col) => {
675                            if col == "*" {
676                                // COUNT(*) - count all rows in partition
677                                return Ok(context
678                                    .get_partition_count(row_index, None)
679                                    .unwrap_or(DataValue::Null));
680                            }
681                            col.clone()
682                        }
683                        SqlExpression::StringLiteral(s) if s == "*" => {
684                            // COUNT(*) as StringLiteral (how parser sends it)
685                            return Ok(context
686                                .get_partition_count(row_index, None)
687                                .unwrap_or(DataValue::Null));
688                        }
689                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
690                    };
691
692                    // COUNT(column) - count non-null values
693                    Ok(context
694                        .get_partition_count(row_index, Some(&column))
695                        .unwrap_or(DataValue::Null))
696                }
697            }
698            _ => Err(anyhow!("Unknown window function: {}", name)),
699        }
700    }
701
702    /// Evaluate a method call on a column (e.g., `column.Trim()`)
703    fn evaluate_method_call(
704        &mut self,
705        object: &str,
706        method: &str,
707        args: &[SqlExpression],
708        row_index: usize,
709    ) -> Result<DataValue> {
710        // Get column value
711        let col_index = self.table.get_column_index(object).ok_or_else(|| {
712            let suggestion = self.find_similar_column(object);
713            match suggestion {
714                Some(similar) => {
715                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
716                }
717                None => anyhow!("Column '{}' not found", object),
718            }
719        })?;
720
721        let cell_value = self.table.get_value(row_index, col_index).cloned();
722
723        self.evaluate_method_on_value(
724            &cell_value.unwrap_or(DataValue::Null),
725            method,
726            args,
727            row_index,
728        )
729    }
730
731    /// Evaluate a method on a value
732    fn evaluate_method_on_value(
733        &mut self,
734        value: &DataValue,
735        method: &str,
736        args: &[SqlExpression],
737        row_index: usize,
738    ) -> Result<DataValue> {
739        // First, try to proxy the method through the function registry
740        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
741
742        // Map method names to function names (case-insensitive matching)
743        let function_name = match method.to_lowercase().as_str() {
744            "trim" => "TRIM",
745            "trimstart" | "trimbegin" => "TRIMSTART",
746            "trimend" => "TRIMEND",
747            "length" | "len" => "LENGTH",
748            "contains" => "CONTAINS",
749            "startswith" => "STARTSWITH",
750            "endswith" => "ENDSWITH",
751            "indexof" => "INDEXOF",
752            _ => method, // Try the method name as-is
753        };
754
755        // Check if we have this function in the registry
756        if self.function_registry.get(function_name).is_some() {
757            debug!(
758                "Proxying method '{}' through function registry as '{}'",
759                method, function_name
760            );
761
762            // Prepare arguments: receiver is the first argument, followed by method args
763            let mut func_args = vec![value.clone()];
764
765            // Evaluate method arguments and add them
766            for arg in args {
767                func_args.push(self.evaluate(arg, row_index)?);
768            }
769
770            // Get the function and call it
771            let func = self.function_registry.get(function_name).unwrap();
772            return func.evaluate(&func_args);
773        }
774
775        // If not in registry, fall back to the old implementation for compatibility
776        match method.to_lowercase().as_str() {
777            "trim" | "trimstart" | "trimend" => {
778                if !args.is_empty() {
779                    return Err(anyhow!("{} takes no arguments", method));
780                }
781
782                // Convert value to string and apply trim
783                let str_val = match value {
784                    DataValue::String(s) => s.clone(),
785                    DataValue::InternedString(s) => s.to_string(),
786                    DataValue::Integer(n) => n.to_string(),
787                    DataValue::Float(f) => f.to_string(),
788                    DataValue::Boolean(b) => b.to_string(),
789                    DataValue::DateTime(dt) => dt.clone(),
790                    DataValue::Null => return Ok(DataValue::Null),
791                };
792
793                let result = match method.to_lowercase().as_str() {
794                    "trim" => str_val.trim().to_string(),
795                    "trimstart" => str_val.trim_start().to_string(),
796                    "trimend" => str_val.trim_end().to_string(),
797                    _ => unreachable!(),
798                };
799
800                Ok(DataValue::String(result))
801            }
802            "length" => {
803                if !args.is_empty() {
804                    return Err(anyhow!("Length takes no arguments"));
805                }
806
807                // Get string length
808                let len = match value {
809                    DataValue::String(s) => s.len(),
810                    DataValue::InternedString(s) => s.len(),
811                    DataValue::Integer(n) => n.to_string().len(),
812                    DataValue::Float(f) => f.to_string().len(),
813                    DataValue::Boolean(b) => b.to_string().len(),
814                    DataValue::DateTime(dt) => dt.len(),
815                    DataValue::Null => return Ok(DataValue::Integer(0)),
816                };
817
818                Ok(DataValue::Integer(len as i64))
819            }
820            "indexof" => {
821                if args.len() != 1 {
822                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
823                }
824
825                // Get the search string from args
826                let search_str = match self.evaluate(&args[0], row_index)? {
827                    DataValue::String(s) => s,
828                    DataValue::InternedString(s) => s.to_string(),
829                    DataValue::Integer(n) => n.to_string(),
830                    DataValue::Float(f) => f.to_string(),
831                    _ => return Err(anyhow!("IndexOf argument must be a string")),
832                };
833
834                // Convert value to string and find index
835                let str_val = match value {
836                    DataValue::String(s) => s.clone(),
837                    DataValue::InternedString(s) => s.to_string(),
838                    DataValue::Integer(n) => n.to_string(),
839                    DataValue::Float(f) => f.to_string(),
840                    DataValue::Boolean(b) => b.to_string(),
841                    DataValue::DateTime(dt) => dt.clone(),
842                    DataValue::Null => return Ok(DataValue::Integer(-1)),
843                };
844
845                let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
846
847                Ok(DataValue::Integer(index))
848            }
849            "contains" => {
850                if args.len() != 1 {
851                    return Err(anyhow!("Contains requires exactly 1 argument"));
852                }
853
854                // Get the search string from args
855                let search_str = match self.evaluate(&args[0], row_index)? {
856                    DataValue::String(s) => s,
857                    DataValue::InternedString(s) => s.to_string(),
858                    DataValue::Integer(n) => n.to_string(),
859                    DataValue::Float(f) => f.to_string(),
860                    _ => return Err(anyhow!("Contains argument must be a string")),
861                };
862
863                // Convert value to string and check contains
864                let str_val = match value {
865                    DataValue::String(s) => s.clone(),
866                    DataValue::InternedString(s) => s.to_string(),
867                    DataValue::Integer(n) => n.to_string(),
868                    DataValue::Float(f) => f.to_string(),
869                    DataValue::Boolean(b) => b.to_string(),
870                    DataValue::DateTime(dt) => dt.clone(),
871                    DataValue::Null => return Ok(DataValue::Boolean(false)),
872                };
873
874                // Case-insensitive search
875                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
876                Ok(DataValue::Boolean(result))
877            }
878            "startswith" => {
879                if args.len() != 1 {
880                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
881                }
882
883                // Get the prefix from args
884                let prefix = match self.evaluate(&args[0], row_index)? {
885                    DataValue::String(s) => s,
886                    DataValue::InternedString(s) => s.to_string(),
887                    DataValue::Integer(n) => n.to_string(),
888                    DataValue::Float(f) => f.to_string(),
889                    _ => return Err(anyhow!("StartsWith argument must be a string")),
890                };
891
892                // Convert value to string and check starts_with
893                let str_val = match value {
894                    DataValue::String(s) => s.clone(),
895                    DataValue::InternedString(s) => s.to_string(),
896                    DataValue::Integer(n) => n.to_string(),
897                    DataValue::Float(f) => f.to_string(),
898                    DataValue::Boolean(b) => b.to_string(),
899                    DataValue::DateTime(dt) => dt.clone(),
900                    DataValue::Null => return Ok(DataValue::Boolean(false)),
901                };
902
903                // Case-insensitive check
904                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
905                Ok(DataValue::Boolean(result))
906            }
907            "endswith" => {
908                if args.len() != 1 {
909                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
910                }
911
912                // Get the suffix from args
913                let suffix = match self.evaluate(&args[0], row_index)? {
914                    DataValue::String(s) => s,
915                    DataValue::InternedString(s) => s.to_string(),
916                    DataValue::Integer(n) => n.to_string(),
917                    DataValue::Float(f) => f.to_string(),
918                    _ => return Err(anyhow!("EndsWith argument must be a string")),
919                };
920
921                // Convert value to string and check ends_with
922                let str_val = match value {
923                    DataValue::String(s) => s.clone(),
924                    DataValue::InternedString(s) => s.to_string(),
925                    DataValue::Integer(n) => n.to_string(),
926                    DataValue::Float(f) => f.to_string(),
927                    DataValue::Boolean(b) => b.to_string(),
928                    DataValue::DateTime(dt) => dt.clone(),
929                    DataValue::Null => return Ok(DataValue::Boolean(false)),
930                };
931
932                // Case-insensitive check
933                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
934                Ok(DataValue::Boolean(result))
935            }
936            _ => Err(anyhow!("Unsupported method: {}", method)),
937        }
938    }
939
940    /// Evaluate a CASE expression
941    fn evaluate_case_expression(
942        &mut self,
943        when_branches: &[crate::sql::recursive_parser::WhenBranch],
944        else_branch: &Option<Box<SqlExpression>>,
945        row_index: usize,
946    ) -> Result<DataValue> {
947        debug!(
948            "ArithmeticEvaluator: evaluating CASE expression for row {}",
949            row_index
950        );
951
952        // Evaluate each WHEN condition in order
953        for branch in when_branches {
954            // Evaluate the condition as a boolean
955            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
956
957            if condition_result {
958                debug!("CASE: WHEN condition matched, evaluating result expression");
959                return self.evaluate(&branch.result, row_index);
960            }
961        }
962
963        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
964        if let Some(else_expr) = else_branch {
965            debug!("CASE: No WHEN matched, evaluating ELSE expression");
966            self.evaluate(else_expr, row_index)
967        } else {
968            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
969            Ok(DataValue::Null)
970        }
971    }
972
973    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
974    fn evaluate_condition_as_bool(
975        &mut self,
976        expr: &SqlExpression,
977        row_index: usize,
978    ) -> Result<bool> {
979        let value = self.evaluate(expr, row_index)?;
980
981        match value {
982            DataValue::Boolean(b) => Ok(b),
983            DataValue::Integer(i) => Ok(i != 0),
984            DataValue::Float(f) => Ok(f != 0.0),
985            DataValue::Null => Ok(false),
986            DataValue::String(s) => Ok(!s.is_empty()),
987            DataValue::InternedString(s) => Ok(!s.is_empty()),
988            _ => Ok(true), // Other types are considered truthy
989        }
990    }
991}
992
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table with columns a = 10 (Integer), b = 2.5 (Float), c = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Deliberately not 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // float literals that approximate PI, which breaks `cargo clippy --tests`.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer zero is falsy.
        let zero = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());

        // Column c holds Integer(4), which is truthy.
        let non_zero = SqlExpression::Column("c".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&non_zero, 0).unwrap());
    }

    #[test]
    fn test_method_call_on_unknown_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .evaluate_method_call("missing", "Trim", &[], 0)
            .unwrap_err();
        assert!(err.to_string().contains("Column 'missing' not found"));
    }
}