sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
8use crate::sql::window_context::WindowContext;
9use anyhow::{anyhow, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::debug;
13
14/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
15/// This is different from `RecursiveWhereEvaluator` which returns boolean
16pub struct ArithmeticEvaluator<'a> {
17    table: &'a DataTable,
18    date_notation: String,
19    function_registry: Arc<FunctionRegistry>,
20    aggregate_registry: Arc<AggregateRegistry>,
21    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
22    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
23}
24
25impl<'a> ArithmeticEvaluator<'a> {
26    #[must_use]
27    pub fn new(table: &'a DataTable) -> Self {
28        Self {
29            table,
30            date_notation: get_date_notation(),
31            function_registry: Arc::new(FunctionRegistry::new()),
32            aggregate_registry: Arc::new(AggregateRegistry::new()),
33            visible_rows: None,
34            window_contexts: HashMap::new(),
35        }
36    }
37
38    #[must_use]
39    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
40        Self {
41            table,
42            date_notation,
43            function_registry: Arc::new(FunctionRegistry::new()),
44            aggregate_registry: Arc::new(AggregateRegistry::new()),
45            visible_rows: None,
46            window_contexts: HashMap::new(),
47        }
48    }
49
50    /// Set visible rows for aggregate functions (for filtered views)
51    #[must_use]
52    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
53        self.visible_rows = Some(rows);
54        self
55    }
56
57    #[must_use]
58    pub fn with_date_notation_and_registry(
59        table: &'a DataTable,
60        date_notation: String,
61        function_registry: Arc<FunctionRegistry>,
62    ) -> Self {
63        Self {
64            table,
65            date_notation,
66            function_registry,
67            aggregate_registry: Arc::new(AggregateRegistry::new()),
68            visible_rows: None,
69            window_contexts: HashMap::new(),
70        }
71    }
72
73    /// Find a column name similar to the given name using edit distance
74    fn find_similar_column(&self, name: &str) -> Option<String> {
75        let columns = self.table.column_names();
76        let mut best_match: Option<(String, usize)> = None;
77
78        for col in columns {
79            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
80            // Only suggest if distance is small (likely a typo)
81            // Allow up to 3 edits for longer names
82            let max_distance = if name.len() > 10 { 3 } else { 2 };
83            if distance <= max_distance {
84                match &best_match {
85                    None => best_match = Some((col, distance)),
86                    Some((_, best_dist)) if distance < *best_dist => {
87                        best_match = Some((col, distance));
88                    }
89                    _ => {}
90                }
91            }
92        }
93
94        best_match.map(|(name, _)| name)
95    }
96
97    /// Calculate Levenshtein edit distance between two strings
98    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
99        // Use the shared implementation from string_methods
100        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
101    }
102
103    /// Evaluate an SQL expression to produce a `DataValue`
104    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
105        debug!(
106            "ArithmeticEvaluator: evaluating {:?} for row {}",
107            expr, row_index
108        );
109
110        match expr {
111            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
112            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
113            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
114            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
115            SqlExpression::Null => Ok(DataValue::Null),
116            SqlExpression::BinaryOp { left, op, right } => {
117                self.evaluate_binary_op(left, op, right, row_index)
118            }
119            SqlExpression::FunctionCall {
120                name,
121                args,
122                distinct,
123            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
124            SqlExpression::WindowFunction {
125                name,
126                args,
127                window_spec,
128            } => self.evaluate_window_function(name, args, window_spec, row_index),
129            SqlExpression::MethodCall {
130                object,
131                method,
132                args,
133            } => self.evaluate_method_call(object, method, args, row_index),
134            SqlExpression::ChainedMethodCall { base, method, args } => {
135                // Evaluate the base expression first, then apply the method
136                let base_value = self.evaluate(base, row_index)?;
137                self.evaluate_method_on_value(&base_value, method, args, row_index)
138            }
139            SqlExpression::CaseExpression {
140                when_branches,
141                else_branch,
142            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
143            _ => Err(anyhow!(
144                "Unsupported expression type for arithmetic evaluation: {:?}",
145                expr
146            )),
147        }
148    }
149
150    /// Evaluate a column reference
151    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
152        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
153            let suggestion = self.find_similar_column(column_name);
154            match suggestion {
155                Some(similar) => anyhow!(
156                    "Column '{}' not found. Did you mean '{}'?",
157                    column_name,
158                    similar
159                ),
160                None => anyhow!("Column '{}' not found", column_name),
161            }
162        })?;
163
164        if row_index >= self.table.row_count() {
165            return Err(anyhow!("Row index {} out of bounds", row_index));
166        }
167
168        let row = self
169            .table
170            .get_row(row_index)
171            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
172
173        let value = row
174            .get(col_index)
175            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
176
177        Ok(value.clone())
178    }
179
180    /// Evaluate a number literal (handles both integers and floats)
181    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
182        // Try to parse as integer first
183        if let Ok(int_val) = number_str.parse::<i64>() {
184            return Ok(DataValue::Integer(int_val));
185        }
186
187        // If that fails, try as float
188        if let Ok(float_val) = number_str.parse::<f64>() {
189            return Ok(DataValue::Float(float_val));
190        }
191
192        Err(anyhow!("Invalid number literal: {}", number_str))
193    }
194
195    /// Evaluate a binary operation (arithmetic)
196    fn evaluate_binary_op(
197        &mut self,
198        left: &SqlExpression,
199        op: &str,
200        right: &SqlExpression,
201        row_index: usize,
202    ) -> Result<DataValue> {
203        let left_val = self.evaluate(left, row_index)?;
204        let right_val = self.evaluate(right, row_index)?;
205
206        debug!(
207            "ArithmeticEvaluator: {} {} {}",
208            self.format_value(&left_val),
209            op,
210            self.format_value(&right_val)
211        );
212
213        match op {
214            "+" => self.add_values(&left_val, &right_val),
215            "-" => self.subtract_values(&left_val, &right_val),
216            "*" => self.multiply_values(&left_val, &right_val),
217            "/" => self.divide_values(&left_val, &right_val),
218            "%" => {
219                // Modulo operator - call MOD function
220                let args = vec![left.clone(), right.clone()];
221                self.evaluate_function("MOD", &args, row_index)
222            }
223            // Comparison operators (return boolean results)
224            // Use centralized comparison logic for consistency
225            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
226                let result = compare_with_op(&left_val, &right_val, op, false);
227                Ok(DataValue::Boolean(result))
228            }
229            // IS NULL / IS NOT NULL operators
230            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
231            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
232            // Logical operators
233            "AND" => {
234                let left_bool = self.to_bool(&left_val)?;
235                let right_bool = self.to_bool(&right_val)?;
236                Ok(DataValue::Boolean(left_bool && right_bool))
237            }
238            "OR" => {
239                let left_bool = self.to_bool(&left_val)?;
240                let right_bool = self.to_bool(&right_val)?;
241                Ok(DataValue::Boolean(left_bool || right_bool))
242            }
243            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
244        }
245    }
246
247    /// Add two `DataValues` with type coercion
248    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
249        // NULL handling - any operation with NULL returns NULL
250        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
251            return Ok(DataValue::Null);
252        }
253
254        match (left, right) {
255            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
256            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
257            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
258            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
259            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
260        }
261    }
262
263    /// Subtract two `DataValues` with type coercion
264    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
265        // NULL handling - any operation with NULL returns NULL
266        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
267            return Ok(DataValue::Null);
268        }
269
270        match (left, right) {
271            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
272            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
273            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
274            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
275            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
276        }
277    }
278
279    /// Multiply two `DataValues` with type coercion
280    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
281        // NULL handling - any operation with NULL returns NULL
282        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
283            return Ok(DataValue::Null);
284        }
285
286        match (left, right) {
287            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
288            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
289            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
290            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
291            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
292        }
293    }
294
295    /// Divide two `DataValues` with type coercion
296    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
297        // NULL handling - any operation with NULL returns NULL
298        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
299            return Ok(DataValue::Null);
300        }
301
302        // Check for division by zero first
303        let is_zero = match right {
304            DataValue::Integer(0) => true,
305            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
306            _ => false,
307        };
308
309        if is_zero {
310            return Err(anyhow!("Division by zero"));
311        }
312
313        match (left, right) {
314            (DataValue::Integer(a), DataValue::Integer(b)) => {
315                // Integer division - if result is exact, keep as int, otherwise promote to float
316                if a % b == 0 {
317                    Ok(DataValue::Integer(a / b))
318                } else {
319                    Ok(DataValue::Float(*a as f64 / *b as f64))
320                }
321            }
322            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
323            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
324            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
325            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
326        }
327    }
328
329    /// Format a `DataValue` for debug output
330    fn format_value(&self, value: &DataValue) -> String {
331        match value {
332            DataValue::Integer(i) => i.to_string(),
333            DataValue::Float(f) => f.to_string(),
334            DataValue::String(s) => format!("'{s}'"),
335            _ => format!("{value:?}"),
336        }
337    }
338
339    /// Convert a DataValue to boolean for logical operations
340    fn to_bool(&self, value: &DataValue) -> Result<bool> {
341        match value {
342            DataValue::Boolean(b) => Ok(*b),
343            DataValue::Integer(i) => Ok(*i != 0),
344            DataValue::Float(f) => Ok(*f != 0.0),
345            DataValue::Null => Ok(false),
346            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
347        }
348    }
349
350    /// Evaluate a function call
351    fn evaluate_function_with_distinct(
352        &mut self,
353        name: &str,
354        args: &[SqlExpression],
355        distinct: bool,
356        row_index: usize,
357    ) -> Result<DataValue> {
358        // If DISTINCT is specified, handle it specially for aggregate functions
359        if distinct {
360            let name_upper = name.to_uppercase();
361
362            // DISTINCT is only valid for aggregate functions
363            if name_upper == "COUNT"
364                || name_upper == "SUM"
365                || name_upper == "AVG"
366                || name_upper == "MIN"
367                || name_upper == "MAX"
368            {
369                return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
370            } else {
371                return Err(anyhow!(
372                    "DISTINCT can only be used with aggregate functions"
373                ));
374            }
375        }
376
377        // Otherwise, use the regular evaluation
378        self.evaluate_function(name, args, row_index)
379    }
380
381    fn evaluate_aggregate_distinct(
382        &mut self,
383        name: &str,
384        args: &[SqlExpression],
385        _row_index: usize,
386    ) -> Result<DataValue> {
387        use std::collections::HashSet;
388
389        if args.is_empty() {
390            return Err(anyhow!("{} DISTINCT requires at least one argument", name));
391        }
392
393        // Determine which rows to process
394        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
395            visible.clone()
396        } else {
397            (0..self.table.rows.len()).collect()
398        };
399
400        // Collect unique values
401        let mut unique_values = HashSet::new();
402        let mut numeric_values = Vec::new();
403
404        for row_idx in &rows_to_process {
405            // Evaluate the expression for this row
406            let value = self.evaluate(&args[0], *row_idx)?;
407
408            // Skip NULL values
409            if matches!(value, DataValue::Null) {
410                continue;
411            }
412
413            // Convert to string for uniqueness check
414            let value_str = match &value {
415                DataValue::String(s) => s.clone(),
416                DataValue::InternedString(s) => s.to_string(),
417                DataValue::Integer(i) => i.to_string(),
418                DataValue::Float(f) => f.to_string(),
419                DataValue::Boolean(b) => b.to_string(),
420                DataValue::DateTime(dt) => dt.to_string(),
421                DataValue::Null => continue,
422            };
423
424            // Only process if we haven't seen this value before
425            if unique_values.insert(value_str) {
426                // For numeric aggregates, collect the numeric value
427                if name != "COUNT" {
428                    match value {
429                        DataValue::Integer(i) => numeric_values.push(i as f64),
430                        DataValue::Float(f) => numeric_values.push(f),
431                        _ => {} // Skip non-numeric for SUM/AVG
432                    }
433                }
434            }
435        }
436
437        // Calculate the result based on the aggregate function
438        match name {
439            "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
440            "SUM" => {
441                if numeric_values.is_empty() {
442                    Ok(DataValue::Null)
443                } else {
444                    let sum: f64 = numeric_values.iter().sum();
445                    if sum.fract() == 0.0 && sum.abs() < 1e10 {
446                        Ok(DataValue::Integer(sum as i64))
447                    } else {
448                        Ok(DataValue::Float(sum))
449                    }
450                }
451            }
452            "AVG" => {
453                if numeric_values.is_empty() {
454                    Ok(DataValue::Null)
455                } else {
456                    let sum: f64 = numeric_values.iter().sum();
457                    Ok(DataValue::Float(sum / numeric_values.len() as f64))
458                }
459            }
460            "MIN" => {
461                if numeric_values.is_empty() {
462                    Ok(DataValue::Null)
463                } else {
464                    let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
465                    if min.fract() == 0.0 && min.abs() < 1e10 {
466                        Ok(DataValue::Integer(min as i64))
467                    } else {
468                        Ok(DataValue::Float(min))
469                    }
470                }
471            }
472            "MAX" => {
473                if numeric_values.is_empty() {
474                    Ok(DataValue::Null)
475                } else {
476                    let max = numeric_values
477                        .iter()
478                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
479                    if max.fract() == 0.0 && max.abs() < 1e10 {
480                        Ok(DataValue::Integer(max as i64))
481                    } else {
482                        Ok(DataValue::Float(max))
483                    }
484                }
485            }
486            _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
487        }
488    }
489
490    fn evaluate_function(
491        &mut self,
492        name: &str,
493        args: &[SqlExpression],
494        row_index: usize,
495    ) -> Result<DataValue> {
496        // Check if this is an aggregate function
497        let name_upper = name.to_uppercase();
498
499        // Handle COUNT(*) special case
500        if name_upper == "COUNT" && args.len() == 1 {
501            match &args[0] {
502                SqlExpression::Column(col) if col == "*" => {
503                    // COUNT(*) - count all rows (or visible rows if filtered)
504                    let count = if let Some(ref visible) = self.visible_rows {
505                        visible.len() as i64
506                    } else {
507                        self.table.rows.len() as i64
508                    };
509                    return Ok(DataValue::Integer(count));
510                }
511                SqlExpression::StringLiteral(s) if s == "*" => {
512                    // COUNT(*) parsed as StringLiteral
513                    let count = if let Some(ref visible) = self.visible_rows {
514                        visible.len() as i64
515                    } else {
516                        self.table.rows.len() as i64
517                    };
518                    return Ok(DataValue::Integer(count));
519                }
520                _ => {
521                    // COUNT(column) - will be handled below
522                }
523            }
524        }
525
526        // Check aggregate registry
527        if self.aggregate_registry.get(&name_upper).is_some() {
528            // Determine which rows to process first
529            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
530                visible.clone()
531            } else {
532                (0..self.table.rows.len()).collect()
533            };
534
535            // Evaluate arguments first if needed (to avoid borrow issues)
536            let values = if !args.is_empty()
537                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
538            {
539                // Evaluate the argument expression for each row
540                let mut vals = Vec::new();
541                for &row_idx in &rows_to_process {
542                    let value = self.evaluate(&args[0], row_idx)?;
543                    vals.push(value);
544                }
545                Some(vals)
546            } else {
547                None
548            };
549
550            // Now get the aggregate function and process
551            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
552            let mut state = agg_func.init();
553
554            if let Some(values) = values {
555                // Use evaluated values
556                for value in &values {
557                    agg_func.accumulate(&mut state, value)?;
558                }
559            } else {
560                // COUNT(*) case
561                for _ in &rows_to_process {
562                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
563                }
564            }
565
566            return Ok(agg_func.finalize(state));
567        }
568
569        // First check if this function exists in the registry
570        if self.function_registry.get(name).is_some() {
571            // Evaluate all arguments first to avoid borrow issues
572            let mut evaluated_args = Vec::new();
573            for arg in args {
574                evaluated_args.push(self.evaluate(arg, row_index)?);
575            }
576
577            // Get the function and call it
578            let func = self.function_registry.get(name).unwrap();
579            return func.evaluate(&evaluated_args);
580        }
581
582        // If not in registry, return error for unknown function
583        Err(anyhow!("Unknown function: {}", name))
584    }
585
586    /// Get or create a WindowContext for the given specification
587    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
588        // Create a key for caching based on the spec
589        let key = format!("{:?}", spec);
590
591        if let Some(context) = self.window_contexts.get(&key) {
592            return Ok(Arc::clone(context));
593        }
594
595        // Create a DataView from the table (with visible rows if filtered)
596        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
597            // Create a filtered view
598            let view = DataView::new(Arc::new(self.table.clone()));
599            // Apply filtering based on visible rows
600            // Note: This is a simplified approach - in production we'd need proper filtering
601            view
602        } else {
603            DataView::new(Arc::new(self.table.clone()))
604        };
605
606        // Create the WindowContext
607        let context = WindowContext::new(
608            Arc::new(data_view),
609            spec.partition_by.clone(),
610            spec.order_by.clone(),
611        )?;
612
613        let context = Arc::new(context);
614        self.window_contexts.insert(key, Arc::clone(&context));
615        Ok(context)
616    }
617
618    /// Evaluate a window function
619    fn evaluate_window_function(
620        &mut self,
621        name: &str,
622        args: &[SqlExpression],
623        spec: &WindowSpec,
624        row_index: usize,
625    ) -> Result<DataValue> {
626        let context = self.get_or_create_window_context(spec)?;
627        let name_upper = name.to_uppercase();
628
629        match name_upper.as_str() {
630            "LAG" => {
631                // LAG(column, offset, default)
632                if args.is_empty() {
633                    return Err(anyhow!("LAG requires at least 1 argument"));
634                }
635
636                // Get column name
637                let column = match &args[0] {
638                    SqlExpression::Column(col) => col.clone(),
639                    _ => return Err(anyhow!("LAG first argument must be a column")),
640                };
641
642                // Get offset (default 1)
643                let offset = if args.len() > 1 {
644                    match self.evaluate(&args[1], row_index)? {
645                        DataValue::Integer(i) => i as i32,
646                        _ => return Err(anyhow!("LAG offset must be an integer")),
647                    }
648                } else {
649                    1
650                };
651
652                // Get value at offset
653                Ok(context
654                    .get_offset_value(row_index, -offset, &column)
655                    .unwrap_or(DataValue::Null))
656            }
657            "LEAD" => {
658                // LEAD(column, offset, default)
659                if args.is_empty() {
660                    return Err(anyhow!("LEAD requires at least 1 argument"));
661                }
662
663                // Get column name
664                let column = match &args[0] {
665                    SqlExpression::Column(col) => col.clone(),
666                    _ => return Err(anyhow!("LEAD first argument must be a column")),
667                };
668
669                // Get offset (default 1)
670                let offset = if args.len() > 1 {
671                    match self.evaluate(&args[1], row_index)? {
672                        DataValue::Integer(i) => i as i32,
673                        _ => return Err(anyhow!("LEAD offset must be an integer")),
674                    }
675                } else {
676                    1
677                };
678
679                // Get value at offset
680                Ok(context
681                    .get_offset_value(row_index, offset, &column)
682                    .unwrap_or(DataValue::Null))
683            }
684            "ROW_NUMBER" => {
685                // ROW_NUMBER() - no arguments
686                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
687            }
688            "FIRST_VALUE" => {
689                // FIRST_VALUE(column)
690                if args.is_empty() {
691                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
692                }
693
694                let column = match &args[0] {
695                    SqlExpression::Column(col) => col.clone(),
696                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
697                };
698
699                Ok(context
700                    .get_first_value(row_index, &column)
701                    .unwrap_or(DataValue::Null))
702            }
703            "LAST_VALUE" => {
704                // LAST_VALUE(column)
705                if args.is_empty() {
706                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
707                }
708
709                let column = match &args[0] {
710                    SqlExpression::Column(col) => col.clone(),
711                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
712                };
713
714                Ok(context
715                    .get_last_value(row_index, &column)
716                    .unwrap_or(DataValue::Null))
717            }
718            "SUM" => {
719                // SUM(column) OVER (PARTITION BY ...)
720                if args.is_empty() {
721                    return Err(anyhow!("SUM requires 1 argument"));
722                }
723
724                let column = match &args[0] {
725                    SqlExpression::Column(col) => col.clone(),
726                    _ => return Err(anyhow!("SUM argument must be a column")),
727                };
728
729                Ok(context
730                    .get_partition_sum(row_index, &column)
731                    .unwrap_or(DataValue::Null))
732            }
733            "COUNT" => {
734                // COUNT(*) or COUNT(column) OVER (PARTITION BY ...)
735                // Note: In window functions, COUNT(*) seems to come with no args
736                if args.is_empty() {
737                    // COUNT(*) OVER (...) - count all rows in partition
738                    Ok(context
739                        .get_partition_count(row_index, None)
740                        .unwrap_or(DataValue::Null))
741                } else {
742                    // Check for COUNT(*)
743                    let column = match &args[0] {
744                        SqlExpression::Column(col) => {
745                            if col == "*" {
746                                // COUNT(*) - count all rows in partition
747                                return Ok(context
748                                    .get_partition_count(row_index, None)
749                                    .unwrap_or(DataValue::Null));
750                            }
751                            col.clone()
752                        }
753                        SqlExpression::StringLiteral(s) if s == "*" => {
754                            // COUNT(*) as StringLiteral (how parser sends it)
755                            return Ok(context
756                                .get_partition_count(row_index, None)
757                                .unwrap_or(DataValue::Null));
758                        }
759                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
760                    };
761
762                    // COUNT(column) - count non-null values
763                    Ok(context
764                        .get_partition_count(row_index, Some(&column))
765                        .unwrap_or(DataValue::Null))
766                }
767            }
768            _ => Err(anyhow!("Unknown window function: {}", name)),
769        }
770    }
771
772    /// Evaluate a method call on a column (e.g., `column.Trim()`)
773    fn evaluate_method_call(
774        &mut self,
775        object: &str,
776        method: &str,
777        args: &[SqlExpression],
778        row_index: usize,
779    ) -> Result<DataValue> {
780        // Get column value
781        let col_index = self.table.get_column_index(object).ok_or_else(|| {
782            let suggestion = self.find_similar_column(object);
783            match suggestion {
784                Some(similar) => {
785                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
786                }
787                None => anyhow!("Column '{}' not found", object),
788            }
789        })?;
790
791        let cell_value = self.table.get_value(row_index, col_index).cloned();
792
793        self.evaluate_method_on_value(
794            &cell_value.unwrap_or(DataValue::Null),
795            method,
796            args,
797            row_index,
798        )
799    }
800
801    /// Evaluate a method on a value
802    fn evaluate_method_on_value(
803        &mut self,
804        value: &DataValue,
805        method: &str,
806        args: &[SqlExpression],
807        row_index: usize,
808    ) -> Result<DataValue> {
809        // First, try to proxy the method through the function registry
810        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
811
812        // Map method names to function names (case-insensitive matching)
813        let function_name = match method.to_lowercase().as_str() {
814            "trim" => "TRIM",
815            "trimstart" | "trimbegin" => "TRIMSTART",
816            "trimend" => "TRIMEND",
817            "length" | "len" => "LENGTH",
818            "contains" => "CONTAINS",
819            "startswith" => "STARTSWITH",
820            "endswith" => "ENDSWITH",
821            "indexof" => "INDEXOF",
822            _ => method, // Try the method name as-is
823        };
824
825        // Check if we have this function in the registry
826        if self.function_registry.get(function_name).is_some() {
827            debug!(
828                "Proxying method '{}' through function registry as '{}'",
829                method, function_name
830            );
831
832            // Prepare arguments: receiver is the first argument, followed by method args
833            let mut func_args = vec![value.clone()];
834
835            // Evaluate method arguments and add them
836            for arg in args {
837                func_args.push(self.evaluate(arg, row_index)?);
838            }
839
840            // Get the function and call it
841            let func = self.function_registry.get(function_name).unwrap();
842            return func.evaluate(&func_args);
843        }
844
845        // If not in registry, fall back to the old implementation for compatibility
846        match method.to_lowercase().as_str() {
847            "trim" | "trimstart" | "trimend" => {
848                if !args.is_empty() {
849                    return Err(anyhow!("{} takes no arguments", method));
850                }
851
852                // Convert value to string and apply trim
853                let str_val = match value {
854                    DataValue::String(s) => s.clone(),
855                    DataValue::InternedString(s) => s.to_string(),
856                    DataValue::Integer(n) => n.to_string(),
857                    DataValue::Float(f) => f.to_string(),
858                    DataValue::Boolean(b) => b.to_string(),
859                    DataValue::DateTime(dt) => dt.clone(),
860                    DataValue::Null => return Ok(DataValue::Null),
861                };
862
863                let result = match method.to_lowercase().as_str() {
864                    "trim" => str_val.trim().to_string(),
865                    "trimstart" => str_val.trim_start().to_string(),
866                    "trimend" => str_val.trim_end().to_string(),
867                    _ => unreachable!(),
868                };
869
870                Ok(DataValue::String(result))
871            }
872            "length" => {
873                if !args.is_empty() {
874                    return Err(anyhow!("Length takes no arguments"));
875                }
876
877                // Get string length
878                let len = match value {
879                    DataValue::String(s) => s.len(),
880                    DataValue::InternedString(s) => s.len(),
881                    DataValue::Integer(n) => n.to_string().len(),
882                    DataValue::Float(f) => f.to_string().len(),
883                    DataValue::Boolean(b) => b.to_string().len(),
884                    DataValue::DateTime(dt) => dt.len(),
885                    DataValue::Null => return Ok(DataValue::Integer(0)),
886                };
887
888                Ok(DataValue::Integer(len as i64))
889            }
890            "indexof" => {
891                if args.len() != 1 {
892                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
893                }
894
895                // Get the search string from args
896                let search_str = match self.evaluate(&args[0], row_index)? {
897                    DataValue::String(s) => s,
898                    DataValue::InternedString(s) => s.to_string(),
899                    DataValue::Integer(n) => n.to_string(),
900                    DataValue::Float(f) => f.to_string(),
901                    _ => return Err(anyhow!("IndexOf argument must be a string")),
902                };
903
904                // Convert value to string and find index
905                let str_val = match value {
906                    DataValue::String(s) => s.clone(),
907                    DataValue::InternedString(s) => s.to_string(),
908                    DataValue::Integer(n) => n.to_string(),
909                    DataValue::Float(f) => f.to_string(),
910                    DataValue::Boolean(b) => b.to_string(),
911                    DataValue::DateTime(dt) => dt.clone(),
912                    DataValue::Null => return Ok(DataValue::Integer(-1)),
913                };
914
915                let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
916
917                Ok(DataValue::Integer(index))
918            }
919            "contains" => {
920                if args.len() != 1 {
921                    return Err(anyhow!("Contains requires exactly 1 argument"));
922                }
923
924                // Get the search string from args
925                let search_str = match self.evaluate(&args[0], row_index)? {
926                    DataValue::String(s) => s,
927                    DataValue::InternedString(s) => s.to_string(),
928                    DataValue::Integer(n) => n.to_string(),
929                    DataValue::Float(f) => f.to_string(),
930                    _ => return Err(anyhow!("Contains argument must be a string")),
931                };
932
933                // Convert value to string and check contains
934                let str_val = match value {
935                    DataValue::String(s) => s.clone(),
936                    DataValue::InternedString(s) => s.to_string(),
937                    DataValue::Integer(n) => n.to_string(),
938                    DataValue::Float(f) => f.to_string(),
939                    DataValue::Boolean(b) => b.to_string(),
940                    DataValue::DateTime(dt) => dt.clone(),
941                    DataValue::Null => return Ok(DataValue::Boolean(false)),
942                };
943
944                // Case-insensitive search
945                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
946                Ok(DataValue::Boolean(result))
947            }
948            "startswith" => {
949                if args.len() != 1 {
950                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
951                }
952
953                // Get the prefix from args
954                let prefix = match self.evaluate(&args[0], row_index)? {
955                    DataValue::String(s) => s,
956                    DataValue::InternedString(s) => s.to_string(),
957                    DataValue::Integer(n) => n.to_string(),
958                    DataValue::Float(f) => f.to_string(),
959                    _ => return Err(anyhow!("StartsWith argument must be a string")),
960                };
961
962                // Convert value to string and check starts_with
963                let str_val = match value {
964                    DataValue::String(s) => s.clone(),
965                    DataValue::InternedString(s) => s.to_string(),
966                    DataValue::Integer(n) => n.to_string(),
967                    DataValue::Float(f) => f.to_string(),
968                    DataValue::Boolean(b) => b.to_string(),
969                    DataValue::DateTime(dt) => dt.clone(),
970                    DataValue::Null => return Ok(DataValue::Boolean(false)),
971                };
972
973                // Case-insensitive check
974                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
975                Ok(DataValue::Boolean(result))
976            }
977            "endswith" => {
978                if args.len() != 1 {
979                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
980                }
981
982                // Get the suffix from args
983                let suffix = match self.evaluate(&args[0], row_index)? {
984                    DataValue::String(s) => s,
985                    DataValue::InternedString(s) => s.to_string(),
986                    DataValue::Integer(n) => n.to_string(),
987                    DataValue::Float(f) => f.to_string(),
988                    _ => return Err(anyhow!("EndsWith argument must be a string")),
989                };
990
991                // Convert value to string and check ends_with
992                let str_val = match value {
993                    DataValue::String(s) => s.clone(),
994                    DataValue::InternedString(s) => s.to_string(),
995                    DataValue::Integer(n) => n.to_string(),
996                    DataValue::Float(f) => f.to_string(),
997                    DataValue::Boolean(b) => b.to_string(),
998                    DataValue::DateTime(dt) => dt.clone(),
999                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1000                };
1001
1002                // Case-insensitive check
1003                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1004                Ok(DataValue::Boolean(result))
1005            }
1006            _ => Err(anyhow!("Unsupported method: {}", method)),
1007        }
1008    }
1009
1010    /// Evaluate a CASE expression
1011    fn evaluate_case_expression(
1012        &mut self,
1013        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1014        else_branch: &Option<Box<SqlExpression>>,
1015        row_index: usize,
1016    ) -> Result<DataValue> {
1017        debug!(
1018            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1019            row_index
1020        );
1021
1022        // Evaluate each WHEN condition in order
1023        for branch in when_branches {
1024            // Evaluate the condition as a boolean
1025            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1026
1027            if condition_result {
1028                debug!("CASE: WHEN condition matched, evaluating result expression");
1029                return self.evaluate(&branch.result, row_index);
1030            }
1031        }
1032
1033        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1034        if let Some(else_expr) = else_branch {
1035            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1036            self.evaluate(else_expr, row_index)
1037        } else {
1038            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1039            Ok(DataValue::Null)
1040        }
1041    }
1042
1043    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1044    fn evaluate_condition_as_bool(
1045        &mut self,
1046        expr: &SqlExpression,
1047        row_index: usize,
1048    ) -> Result<bool> {
1049        let value = self.evaluate(expr, row_index)?;
1050
1051        match value {
1052            DataValue::Boolean(b) => Ok(b),
1053            DataValue::Integer(i) => Ok(i != 0),
1054            DataValue::Float(f) => Ok(f != 0.0),
1055            DataValue::Null => Ok(false),
1056            DataValue::String(s) => Ok(!s.is_empty()),
1057            DataValue::InternedString(s) => Ok(!s.is_empty()),
1058            _ => Ok(true), // Other types are considered truthy
1059        }
1060    }
1061}
1062
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a single-row table with columns `a = 10`, `b = 2.5` and `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in vec!["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    /// A bare column reference yields the cell value of that row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column(String::from("a"));
        let value = evaluator.evaluate(&column_a, 0).unwrap();

        assert_eq!(value, DataValue::Integer(10));
    }

    /// Number literals become integers or floats depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let cases = vec![
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];

        for (literal, expected) in cases {
            let expr = SqlExpression::NumberLiteral(literal.to_string());
            let value = evaluator.evaluate(&expr, 0).unwrap();
            assert_eq!(value, expected);
        }
    }

    /// Addition stays integral for two integers and widens to float otherwise.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let five = DataValue::Integer(5);

        // Integer + Integer
        let int_sum = evaluator.add_values(&five, &DataValue::Integer(3)).unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        // Integer + Float
        let mixed_sum = evaluator.add_values(&five, &DataValue::Float(2.5)).unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Integer * Float produces a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();

        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Integer division stays integral when exact and becomes a float otherwise.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let ten = DataValue::Integer(10);

        // Exact division
        let exact = evaluator
            .divide_values(&ten, &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        // Non-exact division
        let inexact = evaluator
            .divide_values(&ten, &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// A binary expression over two columns is evaluated against the row's cells.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let lhs = SqlExpression::Column(String::from("a"));
        let rhs = SqlExpression::Column(String::from("b"));
        let product_expr = SqlExpression::BinaryOp {
            left: Box::new(lhs),
            op: String::from("*"),
            right: Box::new(rhs),
        };

        let value = evaluator.evaluate(&product_expr, 0).unwrap();
        assert_eq!(value, DataValue::Float(25.0));
    }
}