// sql_cli/data/arithmetic_evaluator.rs

1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
13/// This is different from `RecursiveWhereEvaluator` which returns boolean
14pub struct ArithmeticEvaluator<'a> {
15    table: &'a DataTable,
16    date_notation: String,
17    function_registry: Arc<FunctionRegistry>,
18    aggregate_registry: Arc<AggregateRegistry>,
19    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
20    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
21}
22
23impl<'a> ArithmeticEvaluator<'a> {
24    #[must_use]
25    pub fn new(table: &'a DataTable) -> Self {
26        Self {
27            table,
28            date_notation: "us".to_string(),
29            function_registry: Arc::new(FunctionRegistry::new()),
30            aggregate_registry: Arc::new(AggregateRegistry::new()),
31            visible_rows: None,
32            window_contexts: HashMap::new(),
33        }
34    }
35
36    #[must_use]
37    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38        Self {
39            table,
40            date_notation,
41            function_registry: Arc::new(FunctionRegistry::new()),
42            aggregate_registry: Arc::new(AggregateRegistry::new()),
43            visible_rows: None,
44            window_contexts: HashMap::new(),
45        }
46    }
47
48    /// Set visible rows for aggregate functions (for filtered views)
49    #[must_use]
50    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51        self.visible_rows = Some(rows);
52        self
53    }
54
55    #[must_use]
56    pub fn with_date_notation_and_registry(
57        table: &'a DataTable,
58        date_notation: String,
59        function_registry: Arc<FunctionRegistry>,
60    ) -> Self {
61        Self {
62            table,
63            date_notation,
64            function_registry,
65            aggregate_registry: Arc::new(AggregateRegistry::new()),
66            visible_rows: None,
67            window_contexts: HashMap::new(),
68        }
69    }
70
71    /// Find a column name similar to the given name using edit distance
72    fn find_similar_column(&self, name: &str) -> Option<String> {
73        let columns = self.table.column_names();
74        let mut best_match: Option<(String, usize)> = None;
75
76        for col in columns {
77            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78            // Only suggest if distance is small (likely a typo)
79            // Allow up to 3 edits for longer names
80            let max_distance = if name.len() > 10 { 3 } else { 2 };
81            if distance <= max_distance {
82                match &best_match {
83                    None => best_match = Some((col, distance)),
84                    Some((_, best_dist)) if distance < *best_dist => {
85                        best_match = Some((col, distance));
86                    }
87                    _ => {}
88                }
89            }
90        }
91
92        best_match.map(|(name, _)| name)
93    }
94
95    /// Calculate Levenshtein edit distance between two strings
96    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97        // Use the shared implementation from string_methods
98        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99    }
100
101    /// Evaluate an SQL expression to produce a `DataValue`
102    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103        debug!(
104            "ArithmeticEvaluator: evaluating {:?} for row {}",
105            expr, row_index
106        );
107
108        match expr {
109            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113            SqlExpression::Null => Ok(DataValue::Null),
114            SqlExpression::BinaryOp { left, op, right } => {
115                self.evaluate_binary_op(left, op, right, row_index)
116            }
117            SqlExpression::FunctionCall {
118                name,
119                args,
120                distinct,
121            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
122            SqlExpression::WindowFunction {
123                name,
124                args,
125                window_spec,
126            } => self.evaluate_window_function(name, args, window_spec, row_index),
127            SqlExpression::MethodCall {
128                object,
129                method,
130                args,
131            } => self.evaluate_method_call(object, method, args, row_index),
132            SqlExpression::ChainedMethodCall { base, method, args } => {
133                // Evaluate the base expression first, then apply the method
134                let base_value = self.evaluate(base, row_index)?;
135                self.evaluate_method_on_value(&base_value, method, args, row_index)
136            }
137            SqlExpression::CaseExpression {
138                when_branches,
139                else_branch,
140            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
141            _ => Err(anyhow!(
142                "Unsupported expression type for arithmetic evaluation: {:?}",
143                expr
144            )),
145        }
146    }
147
148    /// Evaluate a column reference
149    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
150        let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
151            let suggestion = self.find_similar_column(column_name);
152            match suggestion {
153                Some(similar) => anyhow!(
154                    "Column '{}' not found. Did you mean '{}'?",
155                    column_name,
156                    similar
157                ),
158                None => anyhow!("Column '{}' not found", column_name),
159            }
160        })?;
161
162        if row_index >= self.table.row_count() {
163            return Err(anyhow!("Row index {} out of bounds", row_index));
164        }
165
166        let row = self
167            .table
168            .get_row(row_index)
169            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
170
171        let value = row
172            .get(col_index)
173            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
174
175        Ok(value.clone())
176    }
177
178    /// Evaluate a number literal (handles both integers and floats)
179    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
180        // Try to parse as integer first
181        if let Ok(int_val) = number_str.parse::<i64>() {
182            return Ok(DataValue::Integer(int_val));
183        }
184
185        // If that fails, try as float
186        if let Ok(float_val) = number_str.parse::<f64>() {
187            return Ok(DataValue::Float(float_val));
188        }
189
190        Err(anyhow!("Invalid number literal: {}", number_str))
191    }
192
193    /// Evaluate a binary operation (arithmetic)
194    fn evaluate_binary_op(
195        &mut self,
196        left: &SqlExpression,
197        op: &str,
198        right: &SqlExpression,
199        row_index: usize,
200    ) -> Result<DataValue> {
201        let left_val = self.evaluate(left, row_index)?;
202        let right_val = self.evaluate(right, row_index)?;
203
204        debug!(
205            "ArithmeticEvaluator: {} {} {}",
206            self.format_value(&left_val),
207            op,
208            self.format_value(&right_val)
209        );
210
211        match op {
212            "+" => self.add_values(&left_val, &right_val),
213            "-" => self.subtract_values(&left_val, &right_val),
214            "*" => self.multiply_values(&left_val, &right_val),
215            "/" => self.divide_values(&left_val, &right_val),
216            "%" => {
217                // Modulo operator - call MOD function
218                let args = vec![left.clone(), right.clone()];
219                self.evaluate_function("MOD", &args, row_index)
220            }
221            // Comparison operators (return boolean results)
222            ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
223            "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
224            ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
225            "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
226            "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
227            "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
228            // IS NULL / IS NOT NULL operators
229            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
230            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
231            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
232        }
233    }
234
235    /// Add two `DataValues` with type coercion
236    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
237        // NULL handling - any operation with NULL returns NULL
238        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
239            return Ok(DataValue::Null);
240        }
241
242        match (left, right) {
243            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
244            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
245            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
246            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
247            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
248        }
249    }
250
251    /// Subtract two `DataValues` with type coercion
252    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
253        // NULL handling - any operation with NULL returns NULL
254        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
255            return Ok(DataValue::Null);
256        }
257
258        match (left, right) {
259            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
260            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
261            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
262            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
263            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
264        }
265    }
266
267    /// Multiply two `DataValues` with type coercion
268    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
269        // NULL handling - any operation with NULL returns NULL
270        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
271            return Ok(DataValue::Null);
272        }
273
274        match (left, right) {
275            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
276            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
277            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
278            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
279            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
280        }
281    }
282
283    /// Divide two `DataValues` with type coercion
284    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
285        // NULL handling - any operation with NULL returns NULL
286        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
287            return Ok(DataValue::Null);
288        }
289
290        // Check for division by zero first
291        let is_zero = match right {
292            DataValue::Integer(0) => true,
293            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
294            _ => false,
295        };
296
297        if is_zero {
298            return Err(anyhow!("Division by zero"));
299        }
300
301        match (left, right) {
302            (DataValue::Integer(a), DataValue::Integer(b)) => {
303                // Integer division - if result is exact, keep as int, otherwise promote to float
304                if a % b == 0 {
305                    Ok(DataValue::Integer(a / b))
306                } else {
307                    Ok(DataValue::Float(*a as f64 / *b as f64))
308                }
309            }
310            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
311            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
312            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
313            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
314        }
315    }
316
317    /// Format a `DataValue` for debug output
318    fn format_value(&self, value: &DataValue) -> String {
319        match value {
320            DataValue::Integer(i) => i.to_string(),
321            DataValue::Float(f) => f.to_string(),
322            DataValue::String(s) => format!("'{s}'"),
323            _ => format!("{value:?}"),
324        }
325    }
326
327    /// Compare two `DataValues` using the provided comparison function
328    fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
329    where
330        F: Fn(f64, f64) -> bool,
331    {
332        debug!(
333            "ArithmeticEvaluator: comparing values {:?} and {:?}",
334            left, right
335        );
336
337        let result = match (left, right) {
338            // Integer comparisons
339            (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
340            (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
341            (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
342            (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
343
344            // String comparisons (lexicographic)
345            (DataValue::String(a), DataValue::String(b)) => {
346                let a_num = a.parse::<f64>();
347                let b_num = b.parse::<f64>();
348                match (a_num, b_num) {
349                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
350                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
351                }
352            }
353            (DataValue::InternedString(a), DataValue::InternedString(b)) => {
354                let a_num = a.parse::<f64>();
355                let b_num = b.parse::<f64>();
356                match (a_num, b_num) {
357                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
358                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
359                }
360            }
361            (DataValue::String(a), DataValue::InternedString(b)) => {
362                let a_num = a.parse::<f64>();
363                let b_num = b.parse::<f64>();
364                match (a_num, b_num) {
365                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
366                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
367                }
368            }
369            (DataValue::InternedString(a), DataValue::String(b)) => {
370                let a_num = a.parse::<f64>();
371                let b_num = b.parse::<f64>();
372                match (a_num, b_num) {
373                    (Ok(a_val), Ok(b_val)) => op(a_val, b_val), // Both are numbers
374                    _ => op(a.len() as f64, b.len() as f64),    // Fallback to length comparison
375                }
376            }
377
378            // Mixed type comparisons (try to convert to numbers)
379            (DataValue::String(a), DataValue::Integer(b)) => {
380                match a.parse::<f64>() {
381                    Ok(a_val) => op(a_val, *b as f64),
382                    Err(_) => false, // String can't be compared with number
383                }
384            }
385            (DataValue::Integer(a), DataValue::String(b)) => {
386                match b.parse::<f64>() {
387                    Ok(b_val) => op(*a as f64, b_val),
388                    Err(_) => false, // String can't be compared with number
389                }
390            }
391            (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
392                Ok(a_val) => op(a_val, *b),
393                Err(_) => false,
394            },
395            (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
396                Ok(b_val) => op(*a, b_val),
397                Err(_) => false,
398            },
399
400            // NULL comparisons
401            (DataValue::Null, _) | (_, DataValue::Null) => false,
402
403            // Boolean comparisons
404            (DataValue::Boolean(a), DataValue::Boolean(b)) => {
405                op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
406            }
407
408            _ => {
409                debug!(
410                    "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
411                    left, right
412                );
413                false
414            }
415        };
416
417        debug!("ArithmeticEvaluator: comparison result: {}", result);
418        Ok(DataValue::Boolean(result))
419    }
420
421    /// Evaluate a function call
422    fn evaluate_function_with_distinct(
423        &mut self,
424        name: &str,
425        args: &[SqlExpression],
426        distinct: bool,
427        row_index: usize,
428    ) -> Result<DataValue> {
429        // If DISTINCT is specified, handle it specially for aggregate functions
430        if distinct {
431            let name_upper = name.to_uppercase();
432
433            // DISTINCT is only valid for aggregate functions
434            if name_upper == "COUNT"
435                || name_upper == "SUM"
436                || name_upper == "AVG"
437                || name_upper == "MIN"
438                || name_upper == "MAX"
439            {
440                return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
441            } else {
442                return Err(anyhow!(
443                    "DISTINCT can only be used with aggregate functions"
444                ));
445            }
446        }
447
448        // Otherwise, use the regular evaluation
449        self.evaluate_function(name, args, row_index)
450    }
451
452    fn evaluate_aggregate_distinct(
453        &mut self,
454        name: &str,
455        args: &[SqlExpression],
456        row_index: usize,
457    ) -> Result<DataValue> {
458        use std::collections::HashSet;
459
460        if args.is_empty() {
461            return Err(anyhow!("{} DISTINCT requires at least one argument", name));
462        }
463
464        // Determine which rows to process
465        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
466            visible.clone()
467        } else {
468            (0..self.table.rows.len()).collect()
469        };
470
471        // Collect unique values
472        let mut unique_values = HashSet::new();
473        let mut numeric_values = Vec::new();
474
475        for row_idx in &rows_to_process {
476            // Evaluate the expression for this row
477            let value = self.evaluate(&args[0], *row_idx)?;
478
479            // Skip NULL values
480            if matches!(value, DataValue::Null) {
481                continue;
482            }
483
484            // Convert to string for uniqueness check
485            let value_str = match &value {
486                DataValue::String(s) => s.clone(),
487                DataValue::InternedString(s) => s.to_string(),
488                DataValue::Integer(i) => i.to_string(),
489                DataValue::Float(f) => f.to_string(),
490                DataValue::Boolean(b) => b.to_string(),
491                DataValue::DateTime(dt) => dt.to_string(),
492                DataValue::Null => continue,
493            };
494
495            // Only process if we haven't seen this value before
496            if unique_values.insert(value_str) {
497                // For numeric aggregates, collect the numeric value
498                if name != "COUNT" {
499                    match value {
500                        DataValue::Integer(i) => numeric_values.push(i as f64),
501                        DataValue::Float(f) => numeric_values.push(f),
502                        _ => {} // Skip non-numeric for SUM/AVG
503                    }
504                }
505            }
506        }
507
508        // Calculate the result based on the aggregate function
509        match name {
510            "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
511            "SUM" => {
512                if numeric_values.is_empty() {
513                    Ok(DataValue::Null)
514                } else {
515                    let sum: f64 = numeric_values.iter().sum();
516                    if sum.fract() == 0.0 && sum.abs() < 1e10 {
517                        Ok(DataValue::Integer(sum as i64))
518                    } else {
519                        Ok(DataValue::Float(sum))
520                    }
521                }
522            }
523            "AVG" => {
524                if numeric_values.is_empty() {
525                    Ok(DataValue::Null)
526                } else {
527                    let sum: f64 = numeric_values.iter().sum();
528                    Ok(DataValue::Float(sum / numeric_values.len() as f64))
529                }
530            }
531            "MIN" => {
532                if numeric_values.is_empty() {
533                    Ok(DataValue::Null)
534                } else {
535                    let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
536                    if min.fract() == 0.0 && min.abs() < 1e10 {
537                        Ok(DataValue::Integer(min as i64))
538                    } else {
539                        Ok(DataValue::Float(min))
540                    }
541                }
542            }
543            "MAX" => {
544                if numeric_values.is_empty() {
545                    Ok(DataValue::Null)
546                } else {
547                    let max = numeric_values
548                        .iter()
549                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
550                    if max.fract() == 0.0 && max.abs() < 1e10 {
551                        Ok(DataValue::Integer(max as i64))
552                    } else {
553                        Ok(DataValue::Float(max))
554                    }
555                }
556            }
557            _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
558        }
559    }
560
561    fn evaluate_function(
562        &mut self,
563        name: &str,
564        args: &[SqlExpression],
565        row_index: usize,
566    ) -> Result<DataValue> {
567        // Check if this is an aggregate function
568        let name_upper = name.to_uppercase();
569
570        // Handle COUNT(*) special case
571        if name_upper == "COUNT" && args.len() == 1 {
572            match &args[0] {
573                SqlExpression::Column(col) if col == "*" => {
574                    // COUNT(*) - count all rows (or visible rows if filtered)
575                    let count = if let Some(ref visible) = self.visible_rows {
576                        visible.len() as i64
577                    } else {
578                        self.table.rows.len() as i64
579                    };
580                    return Ok(DataValue::Integer(count));
581                }
582                SqlExpression::StringLiteral(s) if s == "*" => {
583                    // COUNT(*) parsed as StringLiteral
584                    let count = if let Some(ref visible) = self.visible_rows {
585                        visible.len() as i64
586                    } else {
587                        self.table.rows.len() as i64
588                    };
589                    return Ok(DataValue::Integer(count));
590                }
591                _ => {
592                    // COUNT(column) - will be handled below
593                }
594            }
595        }
596
597        // Check aggregate registry
598        if self.aggregate_registry.get(&name_upper).is_some() {
599            // Determine which rows to process first
600            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
601                visible.clone()
602            } else {
603                (0..self.table.rows.len()).collect()
604            };
605
606            // Evaluate arguments first if needed (to avoid borrow issues)
607            let values = if !args.is_empty()
608                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
609            {
610                // Evaluate the argument expression for each row
611                let mut vals = Vec::new();
612                for &row_idx in &rows_to_process {
613                    let value = self.evaluate(&args[0], row_idx)?;
614                    vals.push(value);
615                }
616                Some(vals)
617            } else {
618                None
619            };
620
621            // Now get the aggregate function and process
622            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
623            let mut state = agg_func.init();
624
625            if let Some(values) = values {
626                // Use evaluated values
627                for value in &values {
628                    agg_func.accumulate(&mut state, value)?;
629                }
630            } else {
631                // COUNT(*) case
632                for _ in &rows_to_process {
633                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
634                }
635            }
636
637            return Ok(agg_func.finalize(state));
638        }
639
640        // First check if this function exists in the registry
641        if self.function_registry.get(name).is_some() {
642            // Evaluate all arguments first to avoid borrow issues
643            let mut evaluated_args = Vec::new();
644            for arg in args {
645                evaluated_args.push(self.evaluate(arg, row_index)?);
646            }
647
648            // Get the function and call it
649            let func = self.function_registry.get(name).unwrap();
650            return func.evaluate(&evaluated_args);
651        }
652
653        // If not in registry, return error for unknown function
654        Err(anyhow!("Unknown function: {}", name))
655    }
656
657    /// Get or create a WindowContext for the given specification
658    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
659        // Create a key for caching based on the spec
660        let key = format!("{:?}", spec);
661
662        if let Some(context) = self.window_contexts.get(&key) {
663            return Ok(Arc::clone(context));
664        }
665
666        // Create a DataView from the table (with visible rows if filtered)
667        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
668            // Create a filtered view
669            let view = DataView::new(Arc::new(self.table.clone()));
670            // Apply filtering based on visible rows
671            // Note: This is a simplified approach - in production we'd need proper filtering
672            view
673        } else {
674            DataView::new(Arc::new(self.table.clone()))
675        };
676
677        // Create the WindowContext
678        let context = WindowContext::new(
679            Arc::new(data_view),
680            spec.partition_by.clone(),
681            spec.order_by.clone(),
682        )?;
683
684        let context = Arc::new(context);
685        self.window_contexts.insert(key, Arc::clone(&context));
686        Ok(context)
687    }
688
689    /// Evaluate a window function
690    fn evaluate_window_function(
691        &mut self,
692        name: &str,
693        args: &[SqlExpression],
694        spec: &WindowSpec,
695        row_index: usize,
696    ) -> Result<DataValue> {
697        let context = self.get_or_create_window_context(spec)?;
698        let name_upper = name.to_uppercase();
699
700        match name_upper.as_str() {
701            "LAG" => {
702                // LAG(column, offset, default)
703                if args.is_empty() {
704                    return Err(anyhow!("LAG requires at least 1 argument"));
705                }
706
707                // Get column name
708                let column = match &args[0] {
709                    SqlExpression::Column(col) => col.clone(),
710                    _ => return Err(anyhow!("LAG first argument must be a column")),
711                };
712
713                // Get offset (default 1)
714                let offset = if args.len() > 1 {
715                    match self.evaluate(&args[1], row_index)? {
716                        DataValue::Integer(i) => i as i32,
717                        _ => return Err(anyhow!("LAG offset must be an integer")),
718                    }
719                } else {
720                    1
721                };
722
723                // Get value at offset
724                Ok(context
725                    .get_offset_value(row_index, -offset, &column)
726                    .unwrap_or(DataValue::Null))
727            }
728            "LEAD" => {
729                // LEAD(column, offset, default)
730                if args.is_empty() {
731                    return Err(anyhow!("LEAD requires at least 1 argument"));
732                }
733
734                // Get column name
735                let column = match &args[0] {
736                    SqlExpression::Column(col) => col.clone(),
737                    _ => return Err(anyhow!("LEAD first argument must be a column")),
738                };
739
740                // Get offset (default 1)
741                let offset = if args.len() > 1 {
742                    match self.evaluate(&args[1], row_index)? {
743                        DataValue::Integer(i) => i as i32,
744                        _ => return Err(anyhow!("LEAD offset must be an integer")),
745                    }
746                } else {
747                    1
748                };
749
750                // Get value at offset
751                Ok(context
752                    .get_offset_value(row_index, offset, &column)
753                    .unwrap_or(DataValue::Null))
754            }
755            "ROW_NUMBER" => {
756                // ROW_NUMBER() - no arguments
757                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
758            }
759            "FIRST_VALUE" => {
760                // FIRST_VALUE(column)
761                if args.is_empty() {
762                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
763                }
764
765                let column = match &args[0] {
766                    SqlExpression::Column(col) => col.clone(),
767                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
768                };
769
770                Ok(context
771                    .get_first_value(row_index, &column)
772                    .unwrap_or(DataValue::Null))
773            }
774            "LAST_VALUE" => {
775                // LAST_VALUE(column)
776                if args.is_empty() {
777                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
778                }
779
780                let column = match &args[0] {
781                    SqlExpression::Column(col) => col.clone(),
782                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
783                };
784
785                Ok(context
786                    .get_last_value(row_index, &column)
787                    .unwrap_or(DataValue::Null))
788            }
789            "SUM" => {
790                // SUM(column) OVER (PARTITION BY ...)
791                if args.is_empty() {
792                    return Err(anyhow!("SUM requires 1 argument"));
793                }
794
795                let column = match &args[0] {
796                    SqlExpression::Column(col) => col.clone(),
797                    _ => return Err(anyhow!("SUM argument must be a column")),
798                };
799
800                Ok(context
801                    .get_partition_sum(row_index, &column)
802                    .unwrap_or(DataValue::Null))
803            }
804            "COUNT" => {
805                // COUNT(*) or COUNT(column) OVER (PARTITION BY ...)
806                // Note: In window functions, COUNT(*) seems to come with no args
807                if args.is_empty() {
808                    // COUNT(*) OVER (...) - count all rows in partition
809                    Ok(context
810                        .get_partition_count(row_index, None)
811                        .unwrap_or(DataValue::Null))
812                } else {
813                    // Check for COUNT(*)
814                    let column = match &args[0] {
815                        SqlExpression::Column(col) => {
816                            if col == "*" {
817                                // COUNT(*) - count all rows in partition
818                                return Ok(context
819                                    .get_partition_count(row_index, None)
820                                    .unwrap_or(DataValue::Null));
821                            }
822                            col.clone()
823                        }
824                        SqlExpression::StringLiteral(s) if s == "*" => {
825                            // COUNT(*) as StringLiteral (how parser sends it)
826                            return Ok(context
827                                .get_partition_count(row_index, None)
828                                .unwrap_or(DataValue::Null));
829                        }
830                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
831                    };
832
833                    // COUNT(column) - count non-null values
834                    Ok(context
835                        .get_partition_count(row_index, Some(&column))
836                        .unwrap_or(DataValue::Null))
837                }
838            }
839            _ => Err(anyhow!("Unknown window function: {}", name)),
840        }
841    }
842
843    /// Evaluate a method call on a column (e.g., `column.Trim()`)
844    fn evaluate_method_call(
845        &mut self,
846        object: &str,
847        method: &str,
848        args: &[SqlExpression],
849        row_index: usize,
850    ) -> Result<DataValue> {
851        // Get column value
852        let col_index = self.table.get_column_index(object).ok_or_else(|| {
853            let suggestion = self.find_similar_column(object);
854            match suggestion {
855                Some(similar) => {
856                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
857                }
858                None => anyhow!("Column '{}' not found", object),
859            }
860        })?;
861
862        let cell_value = self.table.get_value(row_index, col_index).cloned();
863
864        self.evaluate_method_on_value(
865            &cell_value.unwrap_or(DataValue::Null),
866            method,
867            args,
868            row_index,
869        )
870    }
871
872    /// Evaluate a method on a value
873    fn evaluate_method_on_value(
874        &mut self,
875        value: &DataValue,
876        method: &str,
877        args: &[SqlExpression],
878        row_index: usize,
879    ) -> Result<DataValue> {
880        // First, try to proxy the method through the function registry
881        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
882
883        // Map method names to function names (case-insensitive matching)
884        let function_name = match method.to_lowercase().as_str() {
885            "trim" => "TRIM",
886            "trimstart" | "trimbegin" => "TRIMSTART",
887            "trimend" => "TRIMEND",
888            "length" | "len" => "LENGTH",
889            "contains" => "CONTAINS",
890            "startswith" => "STARTSWITH",
891            "endswith" => "ENDSWITH",
892            "indexof" => "INDEXOF",
893            _ => method, // Try the method name as-is
894        };
895
896        // Check if we have this function in the registry
897        if self.function_registry.get(function_name).is_some() {
898            debug!(
899                "Proxying method '{}' through function registry as '{}'",
900                method, function_name
901            );
902
903            // Prepare arguments: receiver is the first argument, followed by method args
904            let mut func_args = vec![value.clone()];
905
906            // Evaluate method arguments and add them
907            for arg in args {
908                func_args.push(self.evaluate(arg, row_index)?);
909            }
910
911            // Get the function and call it
912            let func = self.function_registry.get(function_name).unwrap();
913            return func.evaluate(&func_args);
914        }
915
916        // If not in registry, fall back to the old implementation for compatibility
917        match method.to_lowercase().as_str() {
918            "trim" | "trimstart" | "trimend" => {
919                if !args.is_empty() {
920                    return Err(anyhow!("{} takes no arguments", method));
921                }
922
923                // Convert value to string and apply trim
924                let str_val = match value {
925                    DataValue::String(s) => s.clone(),
926                    DataValue::InternedString(s) => s.to_string(),
927                    DataValue::Integer(n) => n.to_string(),
928                    DataValue::Float(f) => f.to_string(),
929                    DataValue::Boolean(b) => b.to_string(),
930                    DataValue::DateTime(dt) => dt.clone(),
931                    DataValue::Null => return Ok(DataValue::Null),
932                };
933
934                let result = match method.to_lowercase().as_str() {
935                    "trim" => str_val.trim().to_string(),
936                    "trimstart" => str_val.trim_start().to_string(),
937                    "trimend" => str_val.trim_end().to_string(),
938                    _ => unreachable!(),
939                };
940
941                Ok(DataValue::String(result))
942            }
943            "length" => {
944                if !args.is_empty() {
945                    return Err(anyhow!("Length takes no arguments"));
946                }
947
948                // Get string length
949                let len = match value {
950                    DataValue::String(s) => s.len(),
951                    DataValue::InternedString(s) => s.len(),
952                    DataValue::Integer(n) => n.to_string().len(),
953                    DataValue::Float(f) => f.to_string().len(),
954                    DataValue::Boolean(b) => b.to_string().len(),
955                    DataValue::DateTime(dt) => dt.len(),
956                    DataValue::Null => return Ok(DataValue::Integer(0)),
957                };
958
959                Ok(DataValue::Integer(len as i64))
960            }
961            "indexof" => {
962                if args.len() != 1 {
963                    return Err(anyhow!("IndexOf requires exactly 1 argument"));
964                }
965
966                // Get the search string from args
967                let search_str = match self.evaluate(&args[0], row_index)? {
968                    DataValue::String(s) => s,
969                    DataValue::InternedString(s) => s.to_string(),
970                    DataValue::Integer(n) => n.to_string(),
971                    DataValue::Float(f) => f.to_string(),
972                    _ => return Err(anyhow!("IndexOf argument must be a string")),
973                };
974
975                // Convert value to string and find index
976                let str_val = match value {
977                    DataValue::String(s) => s.clone(),
978                    DataValue::InternedString(s) => s.to_string(),
979                    DataValue::Integer(n) => n.to_string(),
980                    DataValue::Float(f) => f.to_string(),
981                    DataValue::Boolean(b) => b.to_string(),
982                    DataValue::DateTime(dt) => dt.clone(),
983                    DataValue::Null => return Ok(DataValue::Integer(-1)),
984                };
985
986                let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
987
988                Ok(DataValue::Integer(index))
989            }
990            "contains" => {
991                if args.len() != 1 {
992                    return Err(anyhow!("Contains requires exactly 1 argument"));
993                }
994
995                // Get the search string from args
996                let search_str = match self.evaluate(&args[0], row_index)? {
997                    DataValue::String(s) => s,
998                    DataValue::InternedString(s) => s.to_string(),
999                    DataValue::Integer(n) => n.to_string(),
1000                    DataValue::Float(f) => f.to_string(),
1001                    _ => return Err(anyhow!("Contains argument must be a string")),
1002                };
1003
1004                // Convert value to string and check contains
1005                let str_val = match value {
1006                    DataValue::String(s) => s.clone(),
1007                    DataValue::InternedString(s) => s.to_string(),
1008                    DataValue::Integer(n) => n.to_string(),
1009                    DataValue::Float(f) => f.to_string(),
1010                    DataValue::Boolean(b) => b.to_string(),
1011                    DataValue::DateTime(dt) => dt.clone(),
1012                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1013                };
1014
1015                // Case-insensitive search
1016                let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1017                Ok(DataValue::Boolean(result))
1018            }
1019            "startswith" => {
1020                if args.len() != 1 {
1021                    return Err(anyhow!("StartsWith requires exactly 1 argument"));
1022                }
1023
1024                // Get the prefix from args
1025                let prefix = match self.evaluate(&args[0], row_index)? {
1026                    DataValue::String(s) => s,
1027                    DataValue::InternedString(s) => s.to_string(),
1028                    DataValue::Integer(n) => n.to_string(),
1029                    DataValue::Float(f) => f.to_string(),
1030                    _ => return Err(anyhow!("StartsWith argument must be a string")),
1031                };
1032
1033                // Convert value to string and check starts_with
1034                let str_val = match value {
1035                    DataValue::String(s) => s.clone(),
1036                    DataValue::InternedString(s) => s.to_string(),
1037                    DataValue::Integer(n) => n.to_string(),
1038                    DataValue::Float(f) => f.to_string(),
1039                    DataValue::Boolean(b) => b.to_string(),
1040                    DataValue::DateTime(dt) => dt.clone(),
1041                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1042                };
1043
1044                // Case-insensitive check
1045                let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1046                Ok(DataValue::Boolean(result))
1047            }
1048            "endswith" => {
1049                if args.len() != 1 {
1050                    return Err(anyhow!("EndsWith requires exactly 1 argument"));
1051                }
1052
1053                // Get the suffix from args
1054                let suffix = match self.evaluate(&args[0], row_index)? {
1055                    DataValue::String(s) => s,
1056                    DataValue::InternedString(s) => s.to_string(),
1057                    DataValue::Integer(n) => n.to_string(),
1058                    DataValue::Float(f) => f.to_string(),
1059                    _ => return Err(anyhow!("EndsWith argument must be a string")),
1060                };
1061
1062                // Convert value to string and check ends_with
1063                let str_val = match value {
1064                    DataValue::String(s) => s.clone(),
1065                    DataValue::InternedString(s) => s.to_string(),
1066                    DataValue::Integer(n) => n.to_string(),
1067                    DataValue::Float(f) => f.to_string(),
1068                    DataValue::Boolean(b) => b.to_string(),
1069                    DataValue::DateTime(dt) => dt.clone(),
1070                    DataValue::Null => return Ok(DataValue::Boolean(false)),
1071                };
1072
1073                // Case-insensitive check
1074                let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1075                Ok(DataValue::Boolean(result))
1076            }
1077            _ => Err(anyhow!("Unsupported method: {}", method)),
1078        }
1079    }
1080
1081    /// Evaluate a CASE expression
1082    fn evaluate_case_expression(
1083        &mut self,
1084        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1085        else_branch: &Option<Box<SqlExpression>>,
1086        row_index: usize,
1087    ) -> Result<DataValue> {
1088        debug!(
1089            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1090            row_index
1091        );
1092
1093        // Evaluate each WHEN condition in order
1094        for branch in when_branches {
1095            // Evaluate the condition as a boolean
1096            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1097
1098            if condition_result {
1099                debug!("CASE: WHEN condition matched, evaluating result expression");
1100                return self.evaluate(&branch.result, row_index);
1101            }
1102        }
1103
1104        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1105        if let Some(else_expr) = else_branch {
1106            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1107            self.evaluate(else_expr, row_index)
1108        } else {
1109            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1110            Ok(DataValue::Null)
1111        }
1112    }
1113
1114    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1115    fn evaluate_condition_as_bool(
1116        &mut self,
1117        expr: &SqlExpression,
1118        row_index: usize,
1119    ) -> Result<bool> {
1120        let value = self.evaluate(expr, row_index)?;
1121
1122        match value {
1123            DataValue::Boolean(b) => Ok(b),
1124            DataValue::Integer(i) => Ok(i != 0),
1125            DataValue::Float(f) => Ok(f != 0.0),
1126            DataValue::Null => Ok(false),
1127            DataValue::String(s) => Ok(!s.is_empty()),
1128            DataValue::InternedString(s) => Ok(!s.is_empty()),
1129            _ => Ok(true), // Other types are considered truthy
1130        }
1131    }
1132}
1133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table: a = 10 (Integer), b = 2.5 (Float), c = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 instead of 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // float literals that approximate PI. 2.75 is also exactly representable.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Non-zero Integer and Float columns are truthy.
        assert!(evaluator
            .evaluate_condition_as_bool(&SqlExpression::Column("a".to_string()), 0)
            .unwrap());
        assert!(evaluator
            .evaluate_condition_as_bool(&SqlExpression::Column("b".to_string()), 0)
            .unwrap());

        // Integer zero is falsy.
        assert!(!evaluator
            .evaluate_condition_as_bool(&SqlExpression::NumberLiteral("0".to_string()), 0)
            .unwrap());
    }

    #[test]
    fn test_method_call_unknown_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // The message may carry a "Did you mean" suffix; only the prefix is pinned.
        let err = evaluator
            .evaluate_method_call("missing", "Trim", &[], 0)
            .unwrap_err();
        assert!(err.to_string().contains("Column 'missing' not found"));
    }
}