sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
8use crate::sql::window_context::WindowContext;
9use anyhow::{anyhow, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::debug;
13
/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses).
///
/// This is different from `RecursiveWhereEvaluator`, which returns a boolean.
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows and columns the expressions are evaluated against.
    table: &'a DataTable,
    /// Date notation setting; `new` takes it from the global configuration.
    date_notation: String,
    /// Registry consulted by name when a function call is not an aggregate.
    function_registry: Arc<FunctionRegistry>,
    /// Registry of aggregate functions, consulted with the upper-cased name.
    aggregate_registry: Arc<AggregateRegistry>,
    /// Row indices aggregates should run over (for filtered views);
    /// `None` means every row of `table`.
    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
    /// Window contexts cached per window specification.
    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
    /// Alias -> table name mapping for qualified column references.
    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
}
25
26impl<'a> ArithmeticEvaluator<'a> {
27    #[must_use]
28    pub fn new(table: &'a DataTable) -> Self {
29        Self {
30            table,
31            date_notation: get_date_notation(),
32            function_registry: Arc::new(FunctionRegistry::new()),
33            aggregate_registry: Arc::new(AggregateRegistry::new()),
34            visible_rows: None,
35            window_contexts: HashMap::new(),
36            table_aliases: HashMap::new(),
37        }
38    }
39
40    #[must_use]
41    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
42        Self {
43            table,
44            date_notation,
45            function_registry: Arc::new(FunctionRegistry::new()),
46            aggregate_registry: Arc::new(AggregateRegistry::new()),
47            visible_rows: None,
48            window_contexts: HashMap::new(),
49            table_aliases: HashMap::new(),
50        }
51    }
52
53    /// Set visible rows for aggregate functions (for filtered views)
54    #[must_use]
55    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
56        self.visible_rows = Some(rows);
57        self
58    }
59
60    /// Set table aliases for qualified column resolution
61    #[must_use]
62    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
63        self.table_aliases = aliases;
64        self
65    }
66
67    #[must_use]
68    pub fn with_date_notation_and_registry(
69        table: &'a DataTable,
70        date_notation: String,
71        function_registry: Arc<FunctionRegistry>,
72    ) -> Self {
73        Self {
74            table,
75            date_notation,
76            function_registry,
77            aggregate_registry: Arc::new(AggregateRegistry::new()),
78            visible_rows: None,
79            window_contexts: HashMap::new(),
80            table_aliases: HashMap::new(),
81        }
82    }
83
84    /// Find a column name similar to the given name using edit distance
85    fn find_similar_column(&self, name: &str) -> Option<String> {
86        let columns = self.table.column_names();
87        let mut best_match: Option<(String, usize)> = None;
88
89        for col in columns {
90            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
91            // Only suggest if distance is small (likely a typo)
92            // Allow up to 3 edits for longer names
93            let max_distance = if name.len() > 10 { 3 } else { 2 };
94            if distance <= max_distance {
95                match &best_match {
96                    None => best_match = Some((col, distance)),
97                    Some((_, best_dist)) if distance < *best_dist => {
98                        best_match = Some((col, distance));
99                    }
100                    _ => {}
101                }
102            }
103        }
104
105        best_match.map(|(name, _)| name)
106    }
107
108    /// Calculate Levenshtein edit distance between two strings
109    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
110        // Use the shared implementation from string_methods
111        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
112    }
113
114    /// Evaluate an SQL expression to produce a `DataValue`
115    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
116        debug!(
117            "ArithmeticEvaluator: evaluating {:?} for row {}",
118            expr, row_index
119        );
120
121        match expr {
122            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
123            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
124            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
125            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
126            SqlExpression::Null => Ok(DataValue::Null),
127            SqlExpression::BinaryOp { left, op, right } => {
128                self.evaluate_binary_op(left, op, right, row_index)
129            }
130            SqlExpression::FunctionCall {
131                name,
132                args,
133                distinct,
134            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
135            SqlExpression::WindowFunction {
136                name,
137                args,
138                window_spec,
139            } => self.evaluate_window_function(name, args, window_spec, row_index),
140            SqlExpression::MethodCall {
141                object,
142                method,
143                args,
144            } => self.evaluate_method_call(object, method, args, row_index),
145            SqlExpression::ChainedMethodCall { base, method, args } => {
146                // Evaluate the base expression first, then apply the method
147                let base_value = self.evaluate(base, row_index)?;
148                self.evaluate_method_on_value(&base_value, method, args, row_index)
149            }
150            SqlExpression::CaseExpression {
151                when_branches,
152                else_branch,
153            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
154            _ => Err(anyhow!(
155                "Unsupported expression type for arithmetic evaluation: {:?}",
156                expr
157            )),
158        }
159    }
160
161    /// Evaluate a column reference
162    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
163        // First try to resolve qualified column names (table.column or alias.column)
164        let resolved_column = if column_name.contains('.') {
165            // Split on last dot to handle cases like "schema.table.column"
166            if let Some(dot_pos) = column_name.rfind('.') {
167                let _table_or_alias = &column_name[..dot_pos];
168                let col_name = &column_name[dot_pos + 1..];
169
170                // For now, just use the column name part
171                // In the future, we could validate the table/alias part
172                debug!(
173                    "Resolving qualified column: {} -> {}",
174                    column_name, col_name
175                );
176                col_name.to_string()
177            } else {
178                column_name.to_string()
179            }
180        } else {
181            column_name.to_string()
182        };
183
184        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
185            idx
186        } else if resolved_column != column_name {
187            // If not found, try the original name
188            if let Some(idx) = self.table.get_column_index(column_name) {
189                idx
190            } else {
191                let suggestion = self.find_similar_column(&resolved_column);
192                return Err(match suggestion {
193                    Some(similar) => anyhow!(
194                        "Column '{}' not found. Did you mean '{}'?",
195                        column_name,
196                        similar
197                    ),
198                    None => anyhow!("Column '{}' not found", column_name),
199                });
200            }
201        } else {
202            let suggestion = self.find_similar_column(&resolved_column);
203            return Err(match suggestion {
204                Some(similar) => anyhow!(
205                    "Column '{}' not found. Did you mean '{}'?",
206                    column_name,
207                    similar
208                ),
209                None => anyhow!("Column '{}' not found", column_name),
210            });
211        };
212
213        if row_index >= self.table.row_count() {
214            return Err(anyhow!("Row index {} out of bounds", row_index));
215        }
216
217        let row = self
218            .table
219            .get_row(row_index)
220            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
221
222        let value = row
223            .get(col_index)
224            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
225
226        Ok(value.clone())
227    }
228
229    /// Evaluate a number literal (handles both integers and floats)
230    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
231        // Try to parse as integer first
232        if let Ok(int_val) = number_str.parse::<i64>() {
233            return Ok(DataValue::Integer(int_val));
234        }
235
236        // If that fails, try as float
237        if let Ok(float_val) = number_str.parse::<f64>() {
238            return Ok(DataValue::Float(float_val));
239        }
240
241        Err(anyhow!("Invalid number literal: {}", number_str))
242    }
243
244    /// Evaluate a binary operation (arithmetic)
245    fn evaluate_binary_op(
246        &mut self,
247        left: &SqlExpression,
248        op: &str,
249        right: &SqlExpression,
250        row_index: usize,
251    ) -> Result<DataValue> {
252        let left_val = self.evaluate(left, row_index)?;
253        let right_val = self.evaluate(right, row_index)?;
254
255        debug!(
256            "ArithmeticEvaluator: {} {} {}",
257            self.format_value(&left_val),
258            op,
259            self.format_value(&right_val)
260        );
261
262        match op {
263            "+" => self.add_values(&left_val, &right_val),
264            "-" => self.subtract_values(&left_val, &right_val),
265            "*" => self.multiply_values(&left_val, &right_val),
266            "/" => self.divide_values(&left_val, &right_val),
267            "%" => {
268                // Modulo operator - call MOD function
269                let args = vec![left.clone(), right.clone()];
270                self.evaluate_function("MOD", &args, row_index)
271            }
272            // Comparison operators (return boolean results)
273            // Use centralized comparison logic for consistency
274            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
275                let result = compare_with_op(&left_val, &right_val, op, false);
276                Ok(DataValue::Boolean(result))
277            }
278            // IS NULL / IS NOT NULL operators
279            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
280            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
281            // Logical operators
282            "AND" => {
283                let left_bool = self.to_bool(&left_val)?;
284                let right_bool = self.to_bool(&right_val)?;
285                Ok(DataValue::Boolean(left_bool && right_bool))
286            }
287            "OR" => {
288                let left_bool = self.to_bool(&left_val)?;
289                let right_bool = self.to_bool(&right_val)?;
290                Ok(DataValue::Boolean(left_bool || right_bool))
291            }
292            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
293        }
294    }
295
296    /// Add two `DataValues` with type coercion
297    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
298        // NULL handling - any operation with NULL returns NULL
299        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
300            return Ok(DataValue::Null);
301        }
302
303        match (left, right) {
304            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
305            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
306            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
307            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
308            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
309        }
310    }
311
312    /// Subtract two `DataValues` with type coercion
313    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314        // NULL handling - any operation with NULL returns NULL
315        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316            return Ok(DataValue::Null);
317        }
318
319        match (left, right) {
320            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
321            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
322            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
323            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
324            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
325        }
326    }
327
328    /// Multiply two `DataValues` with type coercion
329    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330        // NULL handling - any operation with NULL returns NULL
331        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332            return Ok(DataValue::Null);
333        }
334
335        match (left, right) {
336            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
337            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
338            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
339            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
340            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
341        }
342    }
343
344    /// Divide two `DataValues` with type coercion
345    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346        // NULL handling - any operation with NULL returns NULL
347        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348            return Ok(DataValue::Null);
349        }
350
351        // Check for division by zero first
352        let is_zero = match right {
353            DataValue::Integer(0) => true,
354            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
355            _ => false,
356        };
357
358        if is_zero {
359            return Err(anyhow!("Division by zero"));
360        }
361
362        match (left, right) {
363            (DataValue::Integer(a), DataValue::Integer(b)) => {
364                // Integer division - if result is exact, keep as int, otherwise promote to float
365                if a % b == 0 {
366                    Ok(DataValue::Integer(a / b))
367                } else {
368                    Ok(DataValue::Float(*a as f64 / *b as f64))
369                }
370            }
371            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
372            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
373            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
374            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
375        }
376    }
377
378    /// Format a `DataValue` for debug output
379    fn format_value(&self, value: &DataValue) -> String {
380        match value {
381            DataValue::Integer(i) => i.to_string(),
382            DataValue::Float(f) => f.to_string(),
383            DataValue::String(s) => format!("'{s}'"),
384            _ => format!("{value:?}"),
385        }
386    }
387
388    /// Convert a DataValue to boolean for logical operations
389    fn to_bool(&self, value: &DataValue) -> Result<bool> {
390        match value {
391            DataValue::Boolean(b) => Ok(*b),
392            DataValue::Integer(i) => Ok(*i != 0),
393            DataValue::Float(f) => Ok(*f != 0.0),
394            DataValue::Null => Ok(false),
395            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
396        }
397    }
398
399    /// Evaluate a function call
400    fn evaluate_function_with_distinct(
401        &mut self,
402        name: &str,
403        args: &[SqlExpression],
404        distinct: bool,
405        row_index: usize,
406    ) -> Result<DataValue> {
407        // If DISTINCT is specified, handle it specially for aggregate functions
408        if distinct {
409            let name_upper = name.to_uppercase();
410
411            // DISTINCT is only valid for aggregate functions
412            if name_upper == "COUNT"
413                || name_upper == "SUM"
414                || name_upper == "AVG"
415                || name_upper == "MIN"
416                || name_upper == "MAX"
417            {
418                return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
419            } else {
420                return Err(anyhow!(
421                    "DISTINCT can only be used with aggregate functions"
422                ));
423            }
424        }
425
426        // Otherwise, use the regular evaluation
427        self.evaluate_function(name, args, row_index)
428    }
429
430    fn evaluate_aggregate_distinct(
431        &mut self,
432        name: &str,
433        args: &[SqlExpression],
434        _row_index: usize,
435    ) -> Result<DataValue> {
436        use std::collections::HashSet;
437
438        if args.is_empty() {
439            return Err(anyhow!("{} DISTINCT requires at least one argument", name));
440        }
441
442        // Determine which rows to process
443        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
444            visible.clone()
445        } else {
446            (0..self.table.rows.len()).collect()
447        };
448
449        // Collect unique values
450        let mut unique_values = HashSet::new();
451        let mut numeric_values = Vec::new();
452
453        for row_idx in &rows_to_process {
454            // Evaluate the expression for this row
455            let value = self.evaluate(&args[0], *row_idx)?;
456
457            // Skip NULL values
458            if matches!(value, DataValue::Null) {
459                continue;
460            }
461
462            // Convert to string for uniqueness check
463            let value_str = match &value {
464                DataValue::String(s) => s.clone(),
465                DataValue::InternedString(s) => s.to_string(),
466                DataValue::Integer(i) => i.to_string(),
467                DataValue::Float(f) => f.to_string(),
468                DataValue::Boolean(b) => b.to_string(),
469                DataValue::DateTime(dt) => dt.to_string(),
470                DataValue::Null => continue,
471            };
472
473            // Only process if we haven't seen this value before
474            if unique_values.insert(value_str) {
475                // For numeric aggregates, collect the numeric value
476                if name != "COUNT" {
477                    match value {
478                        DataValue::Integer(i) => numeric_values.push(i as f64),
479                        DataValue::Float(f) => numeric_values.push(f),
480                        _ => {} // Skip non-numeric for SUM/AVG
481                    }
482                }
483            }
484        }
485
486        // Calculate the result based on the aggregate function
487        match name {
488            "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
489            "SUM" => {
490                if numeric_values.is_empty() {
491                    Ok(DataValue::Null)
492                } else {
493                    let sum: f64 = numeric_values.iter().sum();
494                    if sum.fract() == 0.0 && sum.abs() < 1e10 {
495                        Ok(DataValue::Integer(sum as i64))
496                    } else {
497                        Ok(DataValue::Float(sum))
498                    }
499                }
500            }
501            "AVG" => {
502                if numeric_values.is_empty() {
503                    Ok(DataValue::Null)
504                } else {
505                    let sum: f64 = numeric_values.iter().sum();
506                    Ok(DataValue::Float(sum / numeric_values.len() as f64))
507                }
508            }
509            "MIN" => {
510                if numeric_values.is_empty() {
511                    Ok(DataValue::Null)
512                } else {
513                    let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
514                    if min.fract() == 0.0 && min.abs() < 1e10 {
515                        Ok(DataValue::Integer(min as i64))
516                    } else {
517                        Ok(DataValue::Float(min))
518                    }
519                }
520            }
521            "MAX" => {
522                if numeric_values.is_empty() {
523                    Ok(DataValue::Null)
524                } else {
525                    let max = numeric_values
526                        .iter()
527                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
528                    if max.fract() == 0.0 && max.abs() < 1e10 {
529                        Ok(DataValue::Integer(max as i64))
530                    } else {
531                        Ok(DataValue::Float(max))
532                    }
533                }
534            }
535            _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
536        }
537    }
538
539    fn evaluate_function(
540        &mut self,
541        name: &str,
542        args: &[SqlExpression],
543        row_index: usize,
544    ) -> Result<DataValue> {
545        // Check if this is an aggregate function
546        let name_upper = name.to_uppercase();
547
548        // Handle COUNT(*) special case
549        if name_upper == "COUNT" && args.len() == 1 {
550            match &args[0] {
551                SqlExpression::Column(col) if col == "*" => {
552                    // COUNT(*) - count all rows (or visible rows if filtered)
553                    let count = if let Some(ref visible) = self.visible_rows {
554                        visible.len() as i64
555                    } else {
556                        self.table.rows.len() as i64
557                    };
558                    return Ok(DataValue::Integer(count));
559                }
560                SqlExpression::StringLiteral(s) if s == "*" => {
561                    // COUNT(*) parsed as StringLiteral
562                    let count = if let Some(ref visible) = self.visible_rows {
563                        visible.len() as i64
564                    } else {
565                        self.table.rows.len() as i64
566                    };
567                    return Ok(DataValue::Integer(count));
568                }
569                _ => {
570                    // COUNT(column) - will be handled below
571                }
572            }
573        }
574
575        // Check aggregate registry
576        if self.aggregate_registry.get(&name_upper).is_some() {
577            // Determine which rows to process first
578            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
579                visible.clone()
580            } else {
581                (0..self.table.rows.len()).collect()
582            };
583
584            // Evaluate arguments first if needed (to avoid borrow issues)
585            let values = if !args.is_empty()
586                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
587            {
588                // Evaluate the argument expression for each row
589                let mut vals = Vec::new();
590                for &row_idx in &rows_to_process {
591                    let value = self.evaluate(&args[0], row_idx)?;
592                    vals.push(value);
593                }
594                Some(vals)
595            } else {
596                None
597            };
598
599            // Now get the aggregate function and process
600            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
601            let mut state = agg_func.init();
602
603            if let Some(values) = values {
604                // Use evaluated values
605                for value in &values {
606                    agg_func.accumulate(&mut state, value)?;
607                }
608            } else {
609                // COUNT(*) case
610                for _ in &rows_to_process {
611                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
612                }
613            }
614
615            return Ok(agg_func.finalize(state));
616        }
617
618        // First check if this function exists in the registry
619        if self.function_registry.get(name).is_some() {
620            // Evaluate all arguments first to avoid borrow issues
621            let mut evaluated_args = Vec::new();
622            for arg in args {
623                evaluated_args.push(self.evaluate(arg, row_index)?);
624            }
625
626            // Get the function and call it
627            let func = self.function_registry.get(name).unwrap();
628            return func.evaluate(&evaluated_args);
629        }
630
631        // If not in registry, return error for unknown function
632        Err(anyhow!("Unknown function: {}", name))
633    }
634
635    /// Get or create a WindowContext for the given specification
636    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
637        // Create a key for caching based on the spec
638        let key = format!("{:?}", spec);
639
640        if let Some(context) = self.window_contexts.get(&key) {
641            return Ok(Arc::clone(context));
642        }
643
644        // Create a DataView from the table (with visible rows if filtered)
645        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
646            // Create a filtered view
647            let view = DataView::new(Arc::new(self.table.clone()));
648            // Apply filtering based on visible rows
649            // Note: This is a simplified approach - in production we'd need proper filtering
650            view
651        } else {
652            DataView::new(Arc::new(self.table.clone()))
653        };
654
655        // Create the WindowContext
656        let context = WindowContext::new(
657            Arc::new(data_view),
658            spec.partition_by.clone(),
659            spec.order_by.clone(),
660        )?;
661
662        let context = Arc::new(context);
663        self.window_contexts.insert(key, Arc::clone(&context));
664        Ok(context)
665    }
666
667    /// Evaluate a window function
668    fn evaluate_window_function(
669        &mut self,
670        name: &str,
671        args: &[SqlExpression],
672        spec: &WindowSpec,
673        row_index: usize,
674    ) -> Result<DataValue> {
675        let context = self.get_or_create_window_context(spec)?;
676        let name_upper = name.to_uppercase();
677
678        match name_upper.as_str() {
679            "LAG" => {
680                // LAG(column, offset, default)
681                if args.is_empty() {
682                    return Err(anyhow!("LAG requires at least 1 argument"));
683                }
684
685                // Get column name
686                let column = match &args[0] {
687                    SqlExpression::Column(col) => col.clone(),
688                    _ => return Err(anyhow!("LAG first argument must be a column")),
689                };
690
691                // Get offset (default 1)
692                let offset = if args.len() > 1 {
693                    match self.evaluate(&args[1], row_index)? {
694                        DataValue::Integer(i) => i as i32,
695                        _ => return Err(anyhow!("LAG offset must be an integer")),
696                    }
697                } else {
698                    1
699                };
700
701                // Get value at offset
702                Ok(context
703                    .get_offset_value(row_index, -offset, &column)
704                    .unwrap_or(DataValue::Null))
705            }
706            "LEAD" => {
707                // LEAD(column, offset, default)
708                if args.is_empty() {
709                    return Err(anyhow!("LEAD requires at least 1 argument"));
710                }
711
712                // Get column name
713                let column = match &args[0] {
714                    SqlExpression::Column(col) => col.clone(),
715                    _ => return Err(anyhow!("LEAD first argument must be a column")),
716                };
717
718                // Get offset (default 1)
719                let offset = if args.len() > 1 {
720                    match self.evaluate(&args[1], row_index)? {
721                        DataValue::Integer(i) => i as i32,
722                        _ => return Err(anyhow!("LEAD offset must be an integer")),
723                    }
724                } else {
725                    1
726                };
727
728                // Get value at offset
729                Ok(context
730                    .get_offset_value(row_index, offset, &column)
731                    .unwrap_or(DataValue::Null))
732            }
733            "ROW_NUMBER" => {
734                // ROW_NUMBER() - no arguments
735                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
736            }
737            "FIRST_VALUE" => {
738                // FIRST_VALUE(column)
739                if args.is_empty() {
740                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
741                }
742
743                let column = match &args[0] {
744                    SqlExpression::Column(col) => col.clone(),
745                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
746                };
747
748                Ok(context
749                    .get_first_value(row_index, &column)
750                    .unwrap_or(DataValue::Null))
751            }
752            "LAST_VALUE" => {
753                // LAST_VALUE(column)
754                if args.is_empty() {
755                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
756                }
757
758                let column = match &args[0] {
759                    SqlExpression::Column(col) => col.clone(),
760                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
761                };
762
763                Ok(context
764                    .get_last_value(row_index, &column)
765                    .unwrap_or(DataValue::Null))
766            }
767            "SUM" => {
768                // SUM(column) OVER (PARTITION BY ...)
769                if args.is_empty() {
770                    return Err(anyhow!("SUM requires 1 argument"));
771                }
772
773                let column = match &args[0] {
774                    SqlExpression::Column(col) => col.clone(),
775                    _ => return Err(anyhow!("SUM argument must be a column")),
776                };
777
778                Ok(context
779                    .get_partition_sum(row_index, &column)
780                    .unwrap_or(DataValue::Null))
781            }
782            "COUNT" => {
783                // COUNT(*) or COUNT(column) OVER (PARTITION BY ...)
784                // Note: In window functions, COUNT(*) seems to come with no args
785                if args.is_empty() {
786                    // COUNT(*) OVER (...) - count all rows in partition
787                    Ok(context
788                        .get_partition_count(row_index, None)
789                        .unwrap_or(DataValue::Null))
790                } else {
791                    // Check for COUNT(*)
792                    let column = match &args[0] {
793                        SqlExpression::Column(col) => {
794                            if col == "*" {
795                                // COUNT(*) - count all rows in partition
796                                return Ok(context
797                                    .get_partition_count(row_index, None)
798                                    .unwrap_or(DataValue::Null));
799                            }
800                            col.clone()
801                        }
802                        SqlExpression::StringLiteral(s) if s == "*" => {
803                            // COUNT(*) as StringLiteral (how parser sends it)
804                            return Ok(context
805                                .get_partition_count(row_index, None)
806                                .unwrap_or(DataValue::Null));
807                        }
808                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
809                    };
810
811                    // COUNT(column) - count non-null values
812                    Ok(context
813                        .get_partition_count(row_index, Some(&column))
814                        .unwrap_or(DataValue::Null))
815                }
816            }
817            _ => Err(anyhow!("Unknown window function: {}", name)),
818        }
819    }
820
821    /// Evaluate a method call on a column (e.g., `column.Trim()`)
822    fn evaluate_method_call(
823        &mut self,
824        object: &str,
825        method: &str,
826        args: &[SqlExpression],
827        row_index: usize,
828    ) -> Result<DataValue> {
829        // Get column value
830        let col_index = self.table.get_column_index(object).ok_or_else(|| {
831            let suggestion = self.find_similar_column(object);
832            match suggestion {
833                Some(similar) => {
834                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
835                }
836                None => anyhow!("Column '{}' not found", object),
837            }
838        })?;
839
840        let cell_value = self.table.get_value(row_index, col_index).cloned();
841
842        self.evaluate_method_on_value(
843            &cell_value.unwrap_or(DataValue::Null),
844            method,
845            args,
846            row_index,
847        )
848    }
849
850    /// Evaluate a method on a value
851    fn evaluate_method_on_value(
852        &mut self,
853        value: &DataValue,
854        method: &str,
855        args: &[SqlExpression],
856        row_index: usize,
857    ) -> Result<DataValue> {
858        // First, try to proxy the method through the function registry
859        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
860
861        // Map method names to function names (case-insensitive matching)
862        let function_name = match method.to_lowercase().as_str() {
863            "trim" => "TRIM",
864            "trimstart" | "trimbegin" => "TRIMSTART",
865            "trimend" => "TRIMEND",
866            "length" | "len" => "LENGTH",
867            "contains" => "CONTAINS",
868            "startswith" => "STARTSWITH",
869            "endswith" => "ENDSWITH",
870            "indexof" => "INDEXOF",
871            _ => method, // Try the method name as-is
872        };
873
874        // Check if we have this function in the registry
875        if self.function_registry.get(function_name).is_some() {
876            debug!(
877                "Proxying method '{}' through function registry as '{}'",
878                method, function_name
879            );
880
881            // Prepare arguments: receiver is the first argument, followed by method args
882            let mut func_args = vec![value.clone()];
883
884            // Evaluate method arguments and add them
885            for arg in args {
886                func_args.push(self.evaluate(arg, row_index)?);
887            }
888
889            // Get the function and call it
890            let func = self.function_registry.get(function_name).unwrap();
891            return func.evaluate(&func_args);
892        }
893
894        // If not in registry, the method is not supported
895        // All methods should be registered in the function registry
896        Err(anyhow!(
897            "Method '{}' not found. It should be registered in the function registry.",
898            method
899        ))
900    }
901
902    /// Evaluate a CASE expression
903    fn evaluate_case_expression(
904        &mut self,
905        when_branches: &[crate::sql::recursive_parser::WhenBranch],
906        else_branch: &Option<Box<SqlExpression>>,
907        row_index: usize,
908    ) -> Result<DataValue> {
909        debug!(
910            "ArithmeticEvaluator: evaluating CASE expression for row {}",
911            row_index
912        );
913
914        // Evaluate each WHEN condition in order
915        for branch in when_branches {
916            // Evaluate the condition as a boolean
917            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
918
919            if condition_result {
920                debug!("CASE: WHEN condition matched, evaluating result expression");
921                return self.evaluate(&branch.result, row_index);
922            }
923        }
924
925        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
926        if let Some(else_expr) = else_branch {
927            debug!("CASE: No WHEN matched, evaluating ELSE expression");
928            self.evaluate(else_expr, row_index)
929        } else {
930            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
931            Ok(DataValue::Null)
932        }
933    }
934
935    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
936    fn evaluate_condition_as_bool(
937        &mut self,
938        expr: &SqlExpression,
939        row_index: usize,
940    ) -> Result<bool> {
941        let value = self.evaluate(expr, row_index)?;
942
943        match value {
944            DataValue::Boolean(b) => Ok(b),
945            DataValue::Integer(i) => Ok(i != 0),
946            DataValue::Float(f) => Ok(f != 0.0),
947            DataValue::Null => Ok(false),
948            DataValue::String(s) => Ok(!s.is_empty()),
949            DataValue::InternedString(s) => Ok(!s.is_empty()),
950            _ => Ok(true), // Other types are considered truthy
951        }
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a single-row table with columns `a = 10`, `b = 2.5` and `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// A bare column reference evaluates to the cell value of that row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    /// Number literals evaluate to integers or floats depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Deliberately not 3.14: a float literal that approximates PI trips
        // Clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    /// Addition keeps integers as integers and promotes mixed operands to float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    /// Multiplying an integer by a float yields a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    /// Integer division stays integral when exact and becomes float otherwise.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by zero is reported as an error rather than a value.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    /// A binary operation over two columns is evaluated against the row's cells.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}