// sql_cli/data/arithmetic_evaluator.rs

1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
13/// This is different from `RecursiveWhereEvaluator` which returns boolean
14pub struct ArithmeticEvaluator<'a> {
15    table: &'a DataTable,
16    date_notation: String,
17    function_registry: Arc<FunctionRegistry>,
18    aggregate_registry: Arc<AggregateRegistry>,
19    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
20    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
21}
22
23impl<'a> ArithmeticEvaluator<'a> {
24    #[must_use]
25    pub fn new(table: &'a DataTable) -> Self {
26        Self {
27            table,
28            date_notation: "us".to_string(),
29            function_registry: Arc::new(FunctionRegistry::new()),
30            aggregate_registry: Arc::new(AggregateRegistry::new()),
31            visible_rows: None,
32            window_contexts: HashMap::new(),
33        }
34    }
35
36    #[must_use]
37    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38        Self {
39            table,
40            date_notation,
41            function_registry: Arc::new(FunctionRegistry::new()),
42            aggregate_registry: Arc::new(AggregateRegistry::new()),
43            visible_rows: None,
44            window_contexts: HashMap::new(),
45        }
46    }
47
48    /// Set visible rows for aggregate functions (for filtered views)
49    #[must_use]
50    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51        self.visible_rows = Some(rows);
52        self
53    }
54
55    #[must_use]
56    pub fn with_date_notation_and_registry(
57        table: &'a DataTable,
58        date_notation: String,
59        function_registry: Arc<FunctionRegistry>,
60    ) -> Self {
61        Self {
62            table,
63            date_notation,
64            function_registry,
65            aggregate_registry: Arc::new(AggregateRegistry::new()),
66            visible_rows: None,
67            window_contexts: HashMap::new(),
68        }
69    }
70
71    /// Find a column name similar to the given name using edit distance
72    fn find_similar_column(&self, name: &str) -> Option<String> {
73        let columns = self.table.column_names();
74        let mut best_match: Option<(String, usize)> = None;
75
76        for col in columns {
77            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78            // Only suggest if distance is small (likely a typo)
79            // Allow up to 3 edits for longer names
80            let max_distance = if name.len() > 10 { 3 } else { 2 };
81            if distance <= max_distance {
82                match &best_match {
83                    None => best_match = Some((col, distance)),
84                    Some((_, best_dist)) if distance < *best_dist => {
85                        best_match = Some((col, distance));
86                    }
87                    _ => {}
88                }
89            }
90        }
91
92        best_match.map(|(name, _)| name)
93    }
94
95    /// Calculate Levenshtein edit distance between two strings
96    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97        // Use the shared implementation from string_methods
98        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99    }
100
101    /// Evaluate an SQL expression to produce a `DataValue`
102    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103        debug!(
104            "ArithmeticEvaluator: evaluating {:?} for row {}",
105            expr, row_index
106        );
107
108        match expr {
109            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113            SqlExpression::BinaryOp { left, op, right } => {
114                self.evaluate_binary_op(left, op, right, row_index)
115            }
116            SqlExpression::FunctionCall { name, args } => {
117                self.evaluate_function(name, args, row_index)
118            }
119            SqlExpression::WindowFunction {
120                name,
121                args,
122                window_spec,
123            } => self.evaluate_window_function(name, args, window_spec, row_index),
124            SqlExpression::MethodCall {
125                object,
126                method,
127                args,
128            } => self.evaluate_method_call(object, method, args, row_index),
129            SqlExpression::ChainedMethodCall { base, method, args } => {
130                // Evaluate the base expression first, then apply the method
131                let base_value = self.evaluate(base, row_index)?;
132                self.evaluate_method_on_value(&base_value, method, args, row_index)
133            }
134            SqlExpression::CaseExpression {
135                when_branches,
136                else_branch,
137            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
138            _ => Err(anyhow!(
139                "Unsupported expression type for arithmetic evaluation: {:?}",
140                expr
141            )),
142        }
143    }
144
145    /// Evaluate a column reference
146    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
147        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
148            let suggestion = self.find_similar_column(column_name);
149            match suggestion {
150                Some(similar) => anyhow!(
151                    "Column '{}' not found. Did you mean '{}'?",
152                    column_name,
153                    similar
154                ),
155                None => anyhow!("Column '{}' not found", column_name),
156            }
157        })?;
158
159        if row_index >= self.table.row_count() {
160            return Err(anyhow!("Row index {} out of bounds", row_index));
161        }
162
163        let row = self
164            .table
165            .get_row(row_index)
166            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
167
168        let value = row
169            .get(col_index)
170            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
171
172        Ok(value.clone())
173    }
174
175    /// Evaluate a number literal (handles both integers and floats)
176    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
177        // Try to parse as integer first
178        if let Ok(int_val) = number_str.parse::<i64>() {
179            return Ok(DataValue::Integer(int_val));
180        }
181
182        // If that fails, try as float
183        if let Ok(float_val) = number_str.parse::<f64>() {
184            return Ok(DataValue::Float(float_val));
185        }
186
187        Err(anyhow!("Invalid number literal: {}", number_str))
188    }
189
190    /// Evaluate a binary operation (arithmetic)
191    fn evaluate_binary_op(
192        &mut self,
193        left: &SqlExpression,
194        op: &str,
195        right: &SqlExpression,
196        row_index: usize,
197    ) -> Result<DataValue> {
198        let left_val = self.evaluate(left, row_index)?;
199        let right_val = self.evaluate(right, row_index)?;
200
201        debug!(
202            "ArithmeticEvaluator: {} {} {}",
203            self.format_value(&left_val),
204            op,
205            self.format_value(&right_val)
206        );
207
208        match op {
209            "+" => self.add_values(&left_val, &right_val),
210            "-" => self.subtract_values(&left_val, &right_val),
211            "*" => self.multiply_values(&left_val, &right_val),
212            "/" => self.divide_values(&left_val, &right_val),
213            // Comparison operators (return boolean results)
214            ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
215            "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
216            ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
217            "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
218            "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
219            "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
220            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
221        }
222    }
223
224    /// Add two `DataValues` with type coercion
225    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
226        // NULL handling - any operation with NULL returns NULL
227        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
228            return Ok(DataValue::Null);
229        }
230
231        match (left, right) {
232            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
233            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
234            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
235            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
236            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
237        }
238    }
239
240    /// Subtract two `DataValues` with type coercion
241    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
242        // NULL handling - any operation with NULL returns NULL
243        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
244            return Ok(DataValue::Null);
245        }
246
247        match (left, right) {
248            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
249            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
250            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
251            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
252            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
253        }
254    }
255
256    /// Multiply two `DataValues` with type coercion
257    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
258        // NULL handling - any operation with NULL returns NULL
259        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
260            return Ok(DataValue::Null);
261        }
262
263        match (left, right) {
264            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
265            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
266            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
267            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
268            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
269        }
270    }
271
272    /// Divide two `DataValues` with type coercion
273    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
274        // NULL handling - any operation with NULL returns NULL
275        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
276            return Ok(DataValue::Null);
277        }
278
279        // Check for division by zero first
280        let is_zero = match right {
281            DataValue::Integer(0) => true,
282            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
283            _ => false,
284        };
285
286        if is_zero {
287            return Err(anyhow!("Division by zero"));
288        }
289
290        match (left, right) {
291            (DataValue::Integer(a), DataValue::Integer(b)) => {
292                // Integer division - if result is exact, keep as int, otherwise promote to float
293                if a % b == 0 {
294                    Ok(DataValue::Integer(a / b))
295                } else {
296                    Ok(DataValue::Float(*a as f64 / *b as f64))
297                }
298            }
299            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
300            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
301            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
302            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
303        }
304    }
305
306    /// Format a `DataValue` for debug output
307    fn format_value(&self, value: &DataValue) -> String {
308        match value {
309            DataValue::Integer(i) => i.to_string(),
310            DataValue::Float(f) => f.to_string(),
311            DataValue::String(s) => format!("'{s}'"),
312            _ => format!("{value:?}"),
313        }
314    }
315
316    /// Compare two `DataValues` using the provided comparison function
317    fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
318    where
319        F: Fn(f64, f64) -> bool,
320    {
321        debug!(
322            "ArithmeticEvaluator: comparing values {:?} and {:?}",
323            left, right
324        );
325
326        let result = match (left, right) {
327            // Integer comparisons
328            (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
329            (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
330            (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
331            (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
332
333            // String comparisons (lexicographic)
334            (DataValue::String(a), DataValue::String(b)) => {
335                let a_num = a.parse::<f64>();
336                let b_num = b.parse::<f64>();
337                match (a_num, b_num) {
338                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
339                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
340                }
341            }
342            (DataValue::InternedString(a), DataValue::InternedString(b)) => {
343                let a_num = a.parse::<f64>();
344                let b_num = b.parse::<f64>();
345                match (a_num, b_num) {
346                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
347                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
348                }
349            }
350            (DataValue::String(a), DataValue::InternedString(b)) => {
351                let a_num = a.parse::<f64>();
352                let b_num = b.parse::<f64>();
353                match (a_num, b_num) {
354                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
355                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
356                }
357            }
358            (DataValue::InternedString(a), DataValue::String(b)) => {
359                let a_num = a.parse::<f64>();
360                let b_num = b.parse::<f64>();
361                match (a_num, b_num) {
362                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
363                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
364                }
365            }
366
367            // Mixed type comparisons (try to convert to numbers)
368            (DataValue::String(a), DataValue::Integer(b)) => {
369                match a.parse::<f64>() {
370                    Ok(a_val) => op(a_val, *b as f64),
371                    Err(_) => false, // String can't be compared with number
372                }
373            }
374            (DataValue::Integer(a), DataValue::String(b)) => {
375                match b.parse::<f64>() {
376                    Ok(b_val) => op(*a as f64, b_val),
377                    Err(_) => false, // String can't be compared with number
378                }
379            }
380            (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
381                Ok(a_val) => op(a_val, *b),
382                Err(_) => false,
383            },
384            (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
385                Ok(b_val) => op(*a, b_val),
386                Err(_) => false,
387            },
388
389            // NULL comparisons
390            (DataValue::Null, _) | (_, DataValue::Null) => false,
391
392            // Boolean comparisons
393            (DataValue::Boolean(a), DataValue::Boolean(b)) => {
394                op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
395            }
396
397            _ => {
398                debug!(
399                    "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
400                    left, right
401                );
402                false
403            }
404        };
405
406        debug!("ArithmeticEvaluator: comparison result: {}", result);
407        Ok(DataValue::Boolean(result))
408    }
409
410    /// Evaluate a function call
411    fn evaluate_function(
412        &mut self,
413        name: &str,
414        args: &[SqlExpression],
415        row_index: usize,
416    ) -> Result<DataValue> {
417        // Check if this is an aggregate function
418        let name_upper = name.to_uppercase();
419
420        // Handle COUNT(*) special case
421        if name_upper == "COUNT" && args.len() == 1 {
422            match &args[0] {
423                SqlExpression::Column(col) if col == "*" => {
424                    // COUNT(*) - count all rows (or visible rows if filtered)
425                    let count = if let Some(ref visible) = self.visible_rows {
426                        visible.len() as i64
427                    } else {
428                        self.table.rows.len() as i64
429                    };
430                    return Ok(DataValue::Integer(count));
431                }
432                SqlExpression::StringLiteral(s) if s == "*" => {
433                    // COUNT(*) parsed as StringLiteral
434                    let count = if let Some(ref visible) = self.visible_rows {
435                        visible.len() as i64
436                    } else {
437                        self.table.rows.len() as i64
438                    };
439                    return Ok(DataValue::Integer(count));
440                }
441                _ => {
442                    // COUNT(column) - will be handled below
443                }
444            }
445        }
446
447        // Check aggregate registry
448        if self.aggregate_registry.get(&name_upper).is_some() {
449            // Determine which rows to process first
450            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
451                visible.clone()
452            } else {
453                (0..self.table.rows.len()).collect()
454            };
455
456            // Evaluate arguments first if needed (to avoid borrow issues)
457            let values = if !args.is_empty()
458                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
459            {
460                // Evaluate the argument expression for each row
461                let mut vals = Vec::new();
462                for &row_idx in &rows_to_process {
463                    let value = self.evaluate(&args[0], row_idx)?;
464                    vals.push(value);
465                }
466                Some(vals)
467            } else {
468                None
469            };
470
471            // Now get the aggregate function and process
472            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
473            let mut state = agg_func.init();
474
475            if let Some(values) = values {
476                // Use evaluated values
477                for value in &values {
478                    agg_func.accumulate(&mut state, value)?;
479                }
480            } else {
481                // COUNT(*) case
482                for _ in &rows_to_process {
483                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
484                }
485            }
486
487            return Ok(agg_func.finalize(state));
488        }
489
490        // First check if this function exists in the registry
491        if self.function_registry.get(name).is_some() {
492            // Evaluate all arguments first to avoid borrow issues
493            let mut evaluated_args = Vec::new();
494            for arg in args {
495                evaluated_args.push(self.evaluate(arg, row_index)?);
496            }
497
498            // Get the function and call it
499            let func = self.function_registry.get(name).unwrap();
500            return func.evaluate(&evaluated_args);
501        }
502
503        // If not in registry, check built-in functions that need special handling
504        // (functions that need access to table data or row context)
505        // Convert function name to uppercase for case-insensitive matching
506        match name_upper.as_str() {
507            "CONVERT" => {
508                if args.len() != 3 {
509                    return Err(anyhow!(
510                        "CONVERT requires exactly 3 arguments: value, from_unit, to_unit"
511                    ));
512                }
513
514                // Evaluate the value
515                let value = self.evaluate(&args[0], row_index)?;
516                let numeric_value = match value {
517                    DataValue::Integer(n) => n as f64,
518                    DataValue::Float(f) => f,
519                    _ => return Err(anyhow!("CONVERT first argument must be numeric")),
520                };
521
522                // Get unit strings
523                let from_unit = match self.evaluate(&args[1], row_index)? {
524                    DataValue::String(s) => s,
525                    DataValue::InternedString(s) => s.to_string(),
526                    _ => {
527                        return Err(anyhow!(
528                            "CONVERT second argument must be a string (from_unit)"
529                        ))
530                    }
531                };
532
533                let to_unit = match self.evaluate(&args[2], row_index)? {
534                    DataValue::String(s) => s,
535                    DataValue::InternedString(s) => s.to_string(),
536                    _ => return Err(anyhow!("CONVERT third argument must be a string (to_unit)")),
537                };
538
539                // Perform conversion
540                match crate::data::unit_converter::convert_units(
541                    numeric_value,
542                    &from_unit,
543                    &to_unit,
544                ) {
545                    Ok(result) => Ok(DataValue::Float(result)),
546                    Err(e) => Err(anyhow!("Unit conversion error: {}", e)),
547                }
548            }
549            _ => Err(anyhow!("Unknown function: {}", name)),
550        }
551    }
552
553    /// Get or create a WindowContext for the given specification
554    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
555        // Create a key for caching based on the spec
556        let key = format!("{:?}", spec);
557
558        if let Some(context) = self.window_contexts.get(&key) {
559            return Ok(Arc::clone(context));
560        }
561
562        // Create a DataView from the table (with visible rows if filtered)
563        let data_view = if let Some(ref visible_rows) = self.visible_rows {
564            // Create a filtered view
565            let mut view = DataView::new(Arc::new(self.table.clone()));
566            // Apply filtering based on visible rows
567            // Note: This is a simplified approach - in production we'd need proper filtering
568            view
569        } else {
570            DataView::new(Arc::new(self.table.clone()))
571        };
572
573        // Create the WindowContext
574        let context = WindowContext::new(
575            Arc::new(data_view),
576            spec.partition_by.clone(),
577            spec.order_by.clone(),
578        )?;
579
580        let context = Arc::new(context);
581        self.window_contexts.insert(key, Arc::clone(&context));
582        Ok(context)
583    }
584
585    /// Evaluate a window function
586    fn evaluate_window_function(
587        &mut self,
588        name: &str,
589        args: &[SqlExpression],
590        spec: &WindowSpec,
591        row_index: usize,
592    ) -> Result<DataValue> {
593        let context = self.get_or_create_window_context(spec)?;
594        let name_upper = name.to_uppercase();
595
596        match name_upper.as_str() {
597            "LAG" => {
598                // LAG(column, offset, default)
599                if args.is_empty() {
600                    return Err(anyhow!("LAG requires at least 1 argument"));
601                }
602
603                // Get column name
604                let column = match &args[0] {
605                    SqlExpression::Column(col) => col.clone(),
606                    _ => return Err(anyhow!("LAG first argument must be a column")),
607                };
608
609                // Get offset (default 1)
610                let offset = if args.len() > 1 {
611                    match self.evaluate(&args[1], row_index)? {
612                        DataValue::Integer(i) => i as i32,
613                        _ => return Err(anyhow!("LAG offset must be an integer")),
614                    }
615                } else {
616                    1
617                };
618
619                // Get value at offset
620                Ok(context
621                    .get_offset_value(row_index, -offset, &column)
622                    .unwrap_or(DataValue::Null))
623            }
624            "LEAD" => {
625                // LEAD(column, offset, default)
626                if args.is_empty() {
627                    return Err(anyhow!("LEAD requires at least 1 argument"));
628                }
629
630                // Get column name
631                let column = match &args[0] {
632                    SqlExpression::Column(col) => col.clone(),
633                    _ => return Err(anyhow!("LEAD first argument must be a column")),
634                };
635
636                // Get offset (default 1)
637                let offset = if args.len() > 1 {
638                    match self.evaluate(&args[1], row_index)? {
639                        DataValue::Integer(i) => i as i32,
640                        _ => return Err(anyhow!("LEAD offset must be an integer")),
641                    }
642                } else {
643                    1
644                };
645
646                // Get value at offset
647                Ok(context
648                    .get_offset_value(row_index, offset, &column)
649                    .unwrap_or(DataValue::Null))
650            }
651            "ROW_NUMBER" => {
652                // ROW_NUMBER() - no arguments
653                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
654            }
655            "FIRST_VALUE" => {
656                // FIRST_VALUE(column)
657                if args.is_empty() {
658                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
659                }
660
661                let column = match &args[0] {
662                    SqlExpression::Column(col) => col.clone(),
663                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
664                };
665
666                Ok(context
667                    .get_first_value(row_index, &column)
668                    .unwrap_or(DataValue::Null))
669            }
670            "LAST_VALUE" => {
671                // LAST_VALUE(column)
672                if args.is_empty() {
673                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
674                }
675
676                let column = match &args[0] {
677                    SqlExpression::Column(col) => col.clone(),
678                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
679                };
680
681                Ok(context
682                    .get_last_value(row_index, &column)
683                    .unwrap_or(DataValue::Null))
684            }
685            _ => Err(anyhow!("Unknown window function: {}", name)),
686        }
687    }
688
689    /// Evaluate a method call on a column (e.g., `column.Trim()`)
690    fn evaluate_method_call(
691        &mut self,
692        object: &str,
693        method: &str,
694        args: &[SqlExpression],
695        row_index: usize,
696    ) -> Result<DataValue> {
697        // Get column value
698        let col_index = self.table.get_column_index(object).ok_or_else(|| {
699            let suggestion = self.find_similar_column(object);
700            match suggestion {
701                Some(similar) => {
702                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
703                }
704                None => anyhow!("Column '{}' not found", object),
705            }
706        })?;
707
708        let cell_value = self.table.get_value(row_index, col_index).cloned();
709
710        self.evaluate_method_on_value(
711            &cell_value.unwrap_or(DataValue::Null),
712            method,
713            args,
714            row_index,
715        )
716    }
717
718    /// Evaluate a method on a value
719    fn evaluate_method_on_value(
720        &mut self,
721        value: &DataValue,
722        method: &str,
723        args: &[SqlExpression],
724        row_index: usize,
725    ) -> Result<DataValue> {
726        // First, try to proxy the method through the function registry
727        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
728
729        // Map method names to function names (case-insensitive matching)
730        let function_name = match method.to_lowercase().as_str() {
731            "trim" => "TRIM",
732            "trimstart" | "trimbegin" => "TRIMSTART",
733            "trimend" => "TRIMEND",
734            "length" | "len" => "LENGTH",
735            "contains" => "CONTAINS",
736            "startswith" => "STARTSWITH",
737            "endswith" => "ENDSWITH",
738            "indexof" => "INDEXOF",
739            _ => method, // Try the method name as-is
740        };
741
742        // Check if we have this function in the registry
743        if self.function_registry.get(function_name).is_some() {
744            debug!(
745                "Proxying method '{}' through function registry as '{}'",
746                method, function_name
747            );
748
749            // Prepare arguments: receiver is the first argument, followed by method args
750            let mut func_args = vec![value.clone()];
751
752            // Evaluate method arguments and add them
753            for arg in args {
754                func_args.push(self.evaluate(arg, row_index)?);
755            }
756
757            // Get the function and call it
758            let func = self.function_registry.get(function_name).unwrap();
759            return func.evaluate(&func_args);
760        }
761
762        // If not in registry, fall back to the old implementation for compatibility
763        match method.to_lowercase().as_str() {
764            "trim" | "trimstart" | "trimend" => {
765                if !args.is_empty() {
766                    return Err(anyhow!("{} takes no arguments", method));
767                }
768
769                // Convert value to string and apply trim
770                let str_val = match value {
771                    DataValue::String(s) => s.clone(),
772                    DataValue::InternedString(s) => s.to_string(),
773                    DataValue::Integer(n) => n.to_string(),
774                    DataValue::Float(f) => f.to_string(),
775                    DataValue::Boolean(b) => b.to_string(),
776                    DataValue::DateTime(dt) => dt.clone(),
777                    DataValue::Null => return Ok(DataValue::Null),
778                };
779
780                let result = match method.to_lowercase().as_str() {
781                    "trim" => str_val.trim().to_string(),
782                    "trimstart" => str_val.trim_start().to_string(),
783                    "trimend" => str_val.trim_end().to_string(),
784                    _ => unreachable!(),
785                };
786
787                Ok(DataValue::String(result))
788            }
789            "length" => {
790                if !args.is_empty() {
791                    return Err(anyhow!("Length takes no arguments"));
792                }
793
794                // Get string length
795                let len = match value {
796                    DataValue::String(s) => s.len(),
797                    DataValue::InternedString(s) => s.len(),
798                    DataValue::Integer(n) => n.to_string().len(),
799                    DataValue::Float(f) => f.to_string().len(),
800                    DataValue::Boolean(b) => b.to_string().len(),
801                    DataValue::DateTime(dt) => dt.len(),
802                    DataValue::Null => return Ok(DataValue::Integer(0)),
803                };
804
805                Ok(DataValue::Integer(len as i64))
806            }
807            "indexof" => {
808                if args.len() != 1 {
809                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
810                }
811
812                // Get the search string from args
813                let search_str = match self.evaluate(&args[0], row_index)? {
814                    DataValue::String(s) => s,
815                    DataValue::InternedString(s) => s.to_string(),
816                    DataValue::Integer(n) => n.to_string(),
817                    DataValue::Float(f) => f.to_string(),
818                    _ => return Err(anyhow!("IndexOf argument must be a string")),
819                };
820
821                // Convert value to string and find index
822                let str_val = match value {
823                    DataValue::String(s) => s.clone(),
824                    DataValue::InternedString(s) => s.to_string(),
825                    DataValue::Integer(n) => n.to_string(),
826                    DataValue::Float(f) => f.to_string(),
827                    DataValue::Boolean(b) => b.to_string(),
828                    DataValue::DateTime(dt) => dt.clone(),
829                    DataValue::Null => return Ok(DataValue::Integer(-1)),
830                };
831
832                let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
833
834                Ok(DataValue::Integer(index))
835            }
836            "contains" => {
837                if args.len() != 1 {
838                    return Err(anyhow!("Contains requires exactly 1 argument"));
839                }
840
841                // Get the search string from args
842                let search_str = match self.evaluate(&args[0], row_index)? {
843                    DataValue::String(s) => s,
844                    DataValue::InternedString(s) => s.to_string(),
845                    DataValue::Integer(n) => n.to_string(),
846                    DataValue::Float(f) => f.to_string(),
847                    _ => return Err(anyhow!("Contains argument must be a string")),
848                };
849
850                // Convert value to string and check contains
851                let str_val = match value {
852                    DataValue::String(s) => s.clone(),
853                    DataValue::InternedString(s) => s.to_string(),
854                    DataValue::Integer(n) => n.to_string(),
855                    DataValue::Float(f) => f.to_string(),
856                    DataValue::Boolean(b) => b.to_string(),
857                    DataValue::DateTime(dt) => dt.clone(),
858                    DataValue::Null => return Ok(DataValue::Boolean(false)),
859                };
860
861                // Case-insensitive search
862                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
863                Ok(DataValue::Boolean(result))
864            }
865            "startswith" => {
866                if args.len() != 1 {
867                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
868                }
869
870                // Get the prefix from args
871                let prefix = match self.evaluate(&args[0], row_index)? {
872                    DataValue::String(s) => s,
873                    DataValue::InternedString(s) => s.to_string(),
874                    DataValue::Integer(n) => n.to_string(),
875                    DataValue::Float(f) => f.to_string(),
876                    _ => return Err(anyhow!("StartsWith argument must be a string")),
877                };
878
879                // Convert value to string and check starts_with
880                let str_val = match value {
881                    DataValue::String(s) => s.clone(),
882                    DataValue::InternedString(s) => s.to_string(),
883                    DataValue::Integer(n) => n.to_string(),
884                    DataValue::Float(f) => f.to_string(),
885                    DataValue::Boolean(b) => b.to_string(),
886                    DataValue::DateTime(dt) => dt.clone(),
887                    DataValue::Null => return Ok(DataValue::Boolean(false)),
888                };
889
890                // Case-insensitive check
891                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
892                Ok(DataValue::Boolean(result))
893            }
894            "endswith" => {
895                if args.len() != 1 {
896                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
897                }
898
899                // Get the suffix from args
900                let suffix = match self.evaluate(&args[0], row_index)? {
901                    DataValue::String(s) => s,
902                    DataValue::InternedString(s) => s.to_string(),
903                    DataValue::Integer(n) => n.to_string(),
904                    DataValue::Float(f) => f.to_string(),
905                    _ => return Err(anyhow!("EndsWith argument must be a string")),
906                };
907
908                // Convert value to string and check ends_with
909                let str_val = match value {
910                    DataValue::String(s) => s.clone(),
911                    DataValue::InternedString(s) => s.to_string(),
912                    DataValue::Integer(n) => n.to_string(),
913                    DataValue::Float(f) => f.to_string(),
914                    DataValue::Boolean(b) => b.to_string(),
915                    DataValue::DateTime(dt) => dt.clone(),
916                    DataValue::Null => return Ok(DataValue::Boolean(false)),
917                };
918
919                // Case-insensitive check
920                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
921                Ok(DataValue::Boolean(result))
922            }
923            _ => Err(anyhow!("Unsupported method: {}", method)),
924        }
925    }
926
927    /// Evaluate a CASE expression
928    fn evaluate_case_expression(
929        &mut self,
930        when_branches: &[crate::sql::recursive_parser::WhenBranch],
931        else_branch: &Option<Box<SqlExpression>>,
932        row_index: usize,
933    ) -> Result<DataValue> {
934        debug!(
935            "ArithmeticEvaluator: evaluating CASE expression for row {}",
936            row_index
937        );
938
939        // Evaluate each WHEN condition in order
940        for branch in when_branches {
941            // Evaluate the condition as a boolean
942            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
943
944            if condition_result {
945                debug!("CASE: WHEN condition matched, evaluating result expression");
946                return self.evaluate(&branch.result, row_index);
947            }
948        }
949
950        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
951        if let Some(else_expr) = else_branch {
952            debug!("CASE: No WHEN matched, evaluating ELSE expression");
953            self.evaluate(else_expr, row_index)
954        } else {
955            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
956            Ok(DataValue::Null)
957        }
958    }
959
960    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
961    fn evaluate_condition_as_bool(
962        &mut self,
963        expr: &SqlExpression,
964        row_index: usize,
965    ) -> Result<bool> {
966        let value = self.evaluate(expr, row_index)?;
967
968        match value {
969            DataValue::Boolean(b) => Ok(b),
970            DataValue::Integer(i) => Ok(i != 0),
971            DataValue::Float(f) => Ok(f != 0.0),
972            DataValue::Null => Ok(false),
973            DataValue::String(s) => Ok(!s.is_empty()),
974            DataValue::InternedString(s) => Ok(!s.is_empty()),
975            _ => Ok(true), // Other types are considered truthy
976        }
977    }
978}
979
#[cfg(test)]
mod tests {
    //! Unit tests for `ArithmeticEvaluator`, run against a one-row table with
    //! `a = 10` (integer), `b = 2.5` (float) and `c = 4` (integer).

    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build the single-row fixture table described in the module docs.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Use an exactly representable literal. `3.14` trips Clippy's deny-by-default
        // `approx_constant` lint (it looks like an approximation of PI).
        let expr = SqlExpression::NumberLiteral("1.25".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(1.25));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Non-zero integer column (a = 10) is truthy
        let expr = SqlExpression::Column("a".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());

        // Non-zero float column (b = 2.5) is truthy
        let expr = SqlExpression::Column("b".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());

        // Integer zero is falsy
        let expr = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());
    }
}