sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
18/// This is different from `RecursiveWhereEvaluator` which returns boolean
19pub struct ArithmeticEvaluator<'a> {
20    table: &'a DataTable,
21    date_notation: String,
22    function_registry: Arc<FunctionRegistry>,
23    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
24    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
25    window_function_registry: Arc<WindowFunctionRegistry>,
26    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
27    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
28    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
29}
30
31impl<'a> ArithmeticEvaluator<'a> {
32    #[must_use]
33    pub fn new(table: &'a DataTable) -> Self {
34        Self {
35            table,
36            date_notation: get_date_notation(),
37            function_registry: Arc::new(FunctionRegistry::new()),
38            aggregate_registry: Arc::new(AggregateRegistry::new()),
39            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41            visible_rows: None,
42            window_contexts: HashMap::new(),
43            table_aliases: HashMap::new(),
44        }
45    }
46
47    #[must_use]
48    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49        Self {
50            table,
51            date_notation,
52            function_registry: Arc::new(FunctionRegistry::new()),
53            aggregate_registry: Arc::new(AggregateRegistry::new()),
54            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56            visible_rows: None,
57            window_contexts: HashMap::new(),
58            table_aliases: HashMap::new(),
59        }
60    }
61
62    /// Set visible rows for aggregate functions (for filtered views)
63    #[must_use]
64    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65        self.visible_rows = Some(rows);
66        self
67    }
68
69    /// Set table aliases for qualified column resolution
70    #[must_use]
71    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72        self.table_aliases = aliases;
73        self
74    }
75
76    #[must_use]
77    pub fn with_date_notation_and_registry(
78        table: &'a DataTable,
79        date_notation: String,
80        function_registry: Arc<FunctionRegistry>,
81    ) -> Self {
82        Self {
83            table,
84            date_notation,
85            function_registry,
86            aggregate_registry: Arc::new(AggregateRegistry::new()),
87            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89            visible_rows: None,
90            window_contexts: HashMap::new(),
91            table_aliases: HashMap::new(),
92        }
93    }
94
95    /// Find a column name similar to the given name using edit distance
96    fn find_similar_column(&self, name: &str) -> Option<String> {
97        let columns = self.table.column_names();
98        let mut best_match: Option<(String, usize)> = None;
99
100        for col in columns {
101            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102            // Only suggest if distance is small (likely a typo)
103            // Allow up to 3 edits for longer names
104            let max_distance = if name.len() > 10 { 3 } else { 2 };
105            if distance <= max_distance {
106                match &best_match {
107                    None => best_match = Some((col, distance)),
108                    Some((_, best_dist)) if distance < *best_dist => {
109                        best_match = Some((col, distance));
110                    }
111                    _ => {}
112                }
113            }
114        }
115
116        best_match.map(|(name, _)| name)
117    }
118
119    /// Calculate Levenshtein edit distance between two strings
120    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121        // Use the shared implementation from string_methods
122        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123    }
124
125    /// Evaluate an SQL expression to produce a `DataValue`
126    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127        debug!(
128            "ArithmeticEvaluator: evaluating {:?} for row {}",
129            expr, row_index
130        );
131
132        match expr {
133            SqlExpression::Column(column_ref) => self.evaluate_column(&column_ref.name, row_index),
134            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137            SqlExpression::Null => Ok(DataValue::Null),
138            SqlExpression::BinaryOp { left, op, right } => {
139                self.evaluate_binary_op(left, op, right, row_index)
140            }
141            SqlExpression::FunctionCall {
142                name,
143                args,
144                distinct,
145            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146            SqlExpression::WindowFunction {
147                name,
148                args,
149                window_spec,
150            } => self.evaluate_window_function(name, args, window_spec, row_index),
151            SqlExpression::MethodCall {
152                object,
153                method,
154                args,
155            } => self.evaluate_method_call(object, method, args, row_index),
156            SqlExpression::ChainedMethodCall { base, method, args } => {
157                // Evaluate the base expression first, then apply the method
158                let base_value = self.evaluate(base, row_index)?;
159                self.evaluate_method_on_value(&base_value, method, args, row_index)
160            }
161            SqlExpression::CaseExpression {
162                when_branches,
163                else_branch,
164            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165            SqlExpression::SimpleCaseExpression {
166                expr,
167                when_branches,
168                else_branch,
169            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170            _ => Err(anyhow!(
171                "Unsupported expression type for arithmetic evaluation: {:?}",
172                expr
173            )),
174        }
175    }
176
177    /// Evaluate a column reference
178    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
179        // First try to resolve qualified column names (table.column or alias.column)
180        let resolved_column = if column_name.contains('.') {
181            // Split on last dot to handle cases like "schema.table.column"
182            if let Some(dot_pos) = column_name.rfind('.') {
183                let _table_or_alias = &column_name[..dot_pos];
184                let col_name = &column_name[dot_pos + 1..];
185
186                // For now, just use the column name part
187                // In the future, we could validate the table/alias part
188                debug!(
189                    "Resolving qualified column: {} -> {}",
190                    column_name, col_name
191                );
192                col_name.to_string()
193            } else {
194                column_name.to_string()
195            }
196        } else {
197            column_name.to_string()
198        };
199
200        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
201            idx
202        } else if resolved_column != column_name {
203            // If not found, try the original name
204            if let Some(idx) = self.table.get_column_index(column_name) {
205                idx
206            } else {
207                let suggestion = self.find_similar_column(&resolved_column);
208                return Err(match suggestion {
209                    Some(similar) => anyhow!(
210                        "Column '{}' not found. Did you mean '{}'?",
211                        column_name,
212                        similar
213                    ),
214                    None => anyhow!("Column '{}' not found", column_name),
215                });
216            }
217        } else {
218            let suggestion = self.find_similar_column(&resolved_column);
219            return Err(match suggestion {
220                Some(similar) => anyhow!(
221                    "Column '{}' not found. Did you mean '{}'?",
222                    column_name,
223                    similar
224                ),
225                None => anyhow!("Column '{}' not found", column_name),
226            });
227        };
228
229        if row_index >= self.table.row_count() {
230            return Err(anyhow!("Row index {} out of bounds", row_index));
231        }
232
233        let row = self
234            .table
235            .get_row(row_index)
236            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
237
238        let value = row
239            .get(col_index)
240            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
241
242        Ok(value.clone())
243    }
244
245    /// Evaluate a number literal (handles both integers and floats)
246    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
247        // Try to parse as integer first
248        if let Ok(int_val) = number_str.parse::<i64>() {
249            return Ok(DataValue::Integer(int_val));
250        }
251
252        // If that fails, try as float
253        if let Ok(float_val) = number_str.parse::<f64>() {
254            return Ok(DataValue::Float(float_val));
255        }
256
257        Err(anyhow!("Invalid number literal: {}", number_str))
258    }
259
260    /// Evaluate a binary operation (arithmetic)
261    fn evaluate_binary_op(
262        &mut self,
263        left: &SqlExpression,
264        op: &str,
265        right: &SqlExpression,
266        row_index: usize,
267    ) -> Result<DataValue> {
268        let left_val = self.evaluate(left, row_index)?;
269        let right_val = self.evaluate(right, row_index)?;
270
271        debug!(
272            "ArithmeticEvaluator: {} {} {}",
273            self.format_value(&left_val),
274            op,
275            self.format_value(&right_val)
276        );
277
278        match op {
279            "+" => self.add_values(&left_val, &right_val),
280            "-" => self.subtract_values(&left_val, &right_val),
281            "*" => self.multiply_values(&left_val, &right_val),
282            "/" => self.divide_values(&left_val, &right_val),
283            "%" => {
284                // Modulo operator - call MOD function
285                let args = vec![left.clone(), right.clone()];
286                self.evaluate_function("MOD", &args, row_index)
287            }
288            // Comparison operators (return boolean results)
289            // Use centralized comparison logic for consistency
290            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
291                let result = compare_with_op(&left_val, &right_val, op, false);
292                Ok(DataValue::Boolean(result))
293            }
294            // IS NULL / IS NOT NULL operators
295            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
296            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
297            // Logical operators
298            "AND" => {
299                let left_bool = self.to_bool(&left_val)?;
300                let right_bool = self.to_bool(&right_val)?;
301                Ok(DataValue::Boolean(left_bool && right_bool))
302            }
303            "OR" => {
304                let left_bool = self.to_bool(&left_val)?;
305                let right_bool = self.to_bool(&right_val)?;
306                Ok(DataValue::Boolean(left_bool || right_bool))
307            }
308            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
309        }
310    }
311
312    /// Add two `DataValues` with type coercion
313    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314        // NULL handling - any operation with NULL returns NULL
315        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316            return Ok(DataValue::Null);
317        }
318
319        match (left, right) {
320            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
321            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
322            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
323            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
324            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
325        }
326    }
327
328    /// Subtract two `DataValues` with type coercion
329    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330        // NULL handling - any operation with NULL returns NULL
331        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332            return Ok(DataValue::Null);
333        }
334
335        match (left, right) {
336            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
337            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
338            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
339            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
340            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
341        }
342    }
343
344    /// Multiply two `DataValues` with type coercion
345    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346        // NULL handling - any operation with NULL returns NULL
347        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348            return Ok(DataValue::Null);
349        }
350
351        match (left, right) {
352            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
353            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
354            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
355            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
356            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
357        }
358    }
359
360    /// Divide two `DataValues` with type coercion
361    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
362        // NULL handling - any operation with NULL returns NULL
363        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
364            return Ok(DataValue::Null);
365        }
366
367        // Check for division by zero first
368        let is_zero = match right {
369            DataValue::Integer(0) => true,
370            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
371            _ => false,
372        };
373
374        if is_zero {
375            return Err(anyhow!("Division by zero"));
376        }
377
378        match (left, right) {
379            (DataValue::Integer(a), DataValue::Integer(b)) => {
380                // Integer division - if result is exact, keep as int, otherwise promote to float
381                if a % b == 0 {
382                    Ok(DataValue::Integer(a / b))
383                } else {
384                    Ok(DataValue::Float(*a as f64 / *b as f64))
385                }
386            }
387            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
388            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
389            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
390            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
391        }
392    }
393
394    /// Format a `DataValue` for debug output
395    fn format_value(&self, value: &DataValue) -> String {
396        match value {
397            DataValue::Integer(i) => i.to_string(),
398            DataValue::Float(f) => f.to_string(),
399            DataValue::String(s) => format!("'{s}'"),
400            _ => format!("{value:?}"),
401        }
402    }
403
404    /// Convert a DataValue to boolean for logical operations
405    fn to_bool(&self, value: &DataValue) -> Result<bool> {
406        match value {
407            DataValue::Boolean(b) => Ok(*b),
408            DataValue::Integer(i) => Ok(*i != 0),
409            DataValue::Float(f) => Ok(*f != 0.0),
410            DataValue::Null => Ok(false),
411            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
412        }
413    }
414
415    /// Evaluate a function call
416    fn evaluate_function_with_distinct(
417        &mut self,
418        name: &str,
419        args: &[SqlExpression],
420        distinct: bool,
421        row_index: usize,
422    ) -> Result<DataValue> {
423        // If DISTINCT is specified, handle it specially for aggregate functions
424        if distinct {
425            let name_upper = name.to_uppercase();
426
427            // Check if it's an aggregate function in either registry
428            if self.aggregate_registry.is_aggregate(&name_upper)
429                || self.new_aggregate_registry.contains(&name_upper)
430            {
431                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
432            } else {
433                return Err(anyhow!(
434                    "DISTINCT can only be used with aggregate functions"
435                ));
436            }
437        }
438
439        // Otherwise, use the regular evaluation
440        self.evaluate_function(name, args, row_index)
441    }
442
443    fn evaluate_aggregate_with_distinct(
444        &mut self,
445        name: &str,
446        args: &[SqlExpression],
447        _row_index: usize,
448    ) -> Result<DataValue> {
449        let name_upper = name.to_uppercase();
450
451        // Check new aggregate registry first for migrated functions
452        if self.new_aggregate_registry.get(&name_upper).is_some() {
453            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
454                visible.clone()
455            } else {
456                (0..self.table.rows.len()).collect()
457            };
458
459            // Collect and deduplicate values for DISTINCT
460            let mut vals = Vec::new();
461            for &row_idx in &rows_to_process {
462                if !args.is_empty() {
463                    let value = self.evaluate(&args[0], row_idx)?;
464                    vals.push(value);
465                }
466            }
467
468            // Deduplicate values
469            let mut seen = HashSet::new();
470            let unique_values: Vec<_> = vals
471                .into_iter()
472                .filter(|v| {
473                    let key = format!("{:?}", v);
474                    seen.insert(key)
475                })
476                .collect();
477
478            // Get the aggregate function from the new registry
479            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
480            let mut state = agg_func.create_state();
481
482            // Use unique values
483            for value in &unique_values {
484                state.accumulate(value)?;
485            }
486
487            return Ok(state.finalize());
488        }
489
490        // Check old aggregate registry (DISTINCT handling)
491        if self.aggregate_registry.get(&name_upper).is_some() {
492            // Determine which rows to process first
493            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
494                visible.clone()
495            } else {
496                (0..self.table.rows.len()).collect()
497            };
498
499            // Special handling for STRING_AGG with separator parameter
500            if name_upper == "STRING_AGG" && args.len() >= 2 {
501                // STRING_AGG(DISTINCT column, separator)
502                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
503                    // Evaluate the separator (second argument) once
504                    if args.len() >= 2 {
505                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
506                        match separator {
507                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
508                            DataValue::InternedString(s) => {
509                                crate::sql::aggregates::StringAggState::new(&s)
510                            }
511                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
512                        }
513                    } else {
514                        crate::sql::aggregates::StringAggState::new(",")
515                    },
516                );
517
518                // Evaluate the first argument (column) for each row and accumulate
519                // Handle DISTINCT - use a HashSet to track seen values
520                let mut seen_values = HashSet::new();
521
522                for &row_idx in &rows_to_process {
523                    let value = self.evaluate(&args[0], row_idx)?;
524
525                    // Skip if we've seen this value
526                    if !seen_values.insert(value.clone()) {
527                        continue; // Skip duplicate values
528                    }
529
530                    // Now get the aggregate function and accumulate
531                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
532                    agg_func.accumulate(&mut state, &value)?;
533                }
534
535                // Finalize the aggregate
536                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
537                return Ok(agg_func.finalize(state));
538            }
539
540            // For other aggregates with DISTINCT
541            // Evaluate the argument expression for each row
542            let mut vals = Vec::new();
543            for &row_idx in &rows_to_process {
544                if !args.is_empty() {
545                    let value = self.evaluate(&args[0], row_idx)?;
546                    vals.push(value);
547                }
548            }
549
550            // Deduplicate values for DISTINCT
551            let mut seen = HashSet::new();
552            let mut unique_values = Vec::new();
553            for value in vals {
554                if seen.insert(value.clone()) {
555                    unique_values.push(value);
556                }
557            }
558
559            // Now get the aggregate function and process
560            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
561            let mut state = agg_func.init();
562
563            // Use unique values
564            for value in &unique_values {
565                agg_func.accumulate(&mut state, value)?;
566            }
567
568            return Ok(agg_func.finalize(state));
569        }
570
571        Err(anyhow!("Unknown aggregate function: {}", name))
572    }
573
574    fn evaluate_function(
575        &mut self,
576        name: &str,
577        args: &[SqlExpression],
578        row_index: usize,
579    ) -> Result<DataValue> {
580        // Check if this is an aggregate function
581        let name_upper = name.to_uppercase();
582
583        // Check new aggregate registry first (for migrated functions)
584        if self.new_aggregate_registry.get(&name_upper).is_some() {
585            // Use new registry for SUM
586            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
587                visible.clone()
588            } else {
589                (0..self.table.rows.len()).collect()
590            };
591
592            // Get the aggregate function from the new registry
593            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
594            let mut state = agg_func.create_state();
595
596            // Special handling for COUNT(*)
597            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
598                if args.is_empty()
599                    || (args.len() == 1
600                        && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
601                    || (args.len() == 1
602                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
603                {
604                    // COUNT(*) or COUNT_STAR - count all rows
605                    for _ in &rows_to_process {
606                        state.accumulate(&DataValue::Integer(1))?;
607                    }
608                } else {
609                    // COUNT(column) - count non-null values
610                    for &row_idx in &rows_to_process {
611                        let value = self.evaluate(&args[0], row_idx)?;
612                        state.accumulate(&value)?;
613                    }
614                }
615            } else {
616                // Other aggregates - evaluate arguments and accumulate
617                if !args.is_empty() {
618                    for &row_idx in &rows_to_process {
619                        let value = self.evaluate(&args[0], row_idx)?;
620                        state.accumulate(&value)?;
621                    }
622                }
623            }
624
625            return Ok(state.finalize());
626        }
627
628        // Check old aggregate registry (for non-migrated functions)
629        if self.aggregate_registry.get(&name_upper).is_some() {
630            // Determine which rows to process first
631            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
632                visible.clone()
633            } else {
634                (0..self.table.rows.len()).collect()
635            };
636
637            // Special handling for STRING_AGG with separator parameter
638            if name_upper == "STRING_AGG" && args.len() >= 2 {
639                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
640                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
641                    // Evaluate the separator (second argument) once
642                    if args.len() >= 2 {
643                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
644                        match separator {
645                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
646                            DataValue::InternedString(s) => {
647                                crate::sql::aggregates::StringAggState::new(&s)
648                            }
649                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
650                        }
651                    } else {
652                        crate::sql::aggregates::StringAggState::new(",")
653                    },
654                );
655
656                // Evaluate the first argument (column) for each row and accumulate
657                for &row_idx in &rows_to_process {
658                    let value = self.evaluate(&args[0], row_idx)?;
659                    // Now get the aggregate function and accumulate
660                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
661                    agg_func.accumulate(&mut state, &value)?;
662                }
663
664                // Finalize the aggregate
665                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
666                return Ok(agg_func.finalize(state));
667            }
668
669            // Evaluate arguments first if needed (to avoid borrow issues)
670            let values = if !args.is_empty()
671                && !(args.len() == 1
672                    && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
673            {
674                // Evaluate the argument expression for each row
675                let mut vals = Vec::new();
676                for &row_idx in &rows_to_process {
677                    let value = self.evaluate(&args[0], row_idx)?;
678                    vals.push(value);
679                }
680                Some(vals)
681            } else {
682                None
683            };
684
685            // Now get the aggregate function and process
686            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
687            let mut state = agg_func.init();
688
689            if let Some(values) = values {
690                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
691                for value in &values {
692                    agg_func.accumulate(&mut state, value)?;
693                }
694            } else {
695                // COUNT(*) case
696                for _ in &rows_to_process {
697                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
698                }
699            }
700
701            return Ok(agg_func.finalize(state));
702        }
703
704        // First check if this function exists in the registry
705        if self.function_registry.get(name).is_some() {
706            // Evaluate all arguments first to avoid borrow issues
707            let mut evaluated_args = Vec::new();
708            for arg in args {
709                evaluated_args.push(self.evaluate(arg, row_index)?);
710            }
711
712            // Get the function and call it
713            let func = self.function_registry.get(name).unwrap();
714            return func.evaluate(&evaluated_args);
715        }
716
717        // If not in registry, return error for unknown function
718        Err(anyhow!("Unknown function: {}", name))
719    }
720
721    /// Get or create a WindowContext for the given specification
722    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
723        // Create a key for caching based on the spec
724        let key = format!("{:?}", spec);
725
726        if let Some(context) = self.window_contexts.get(&key) {
727            return Ok(Arc::clone(context));
728        }
729
730        // Create a DataView from the table (with visible rows if filtered)
731        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
732            // Create a filtered view
733            let view = DataView::new(Arc::new(self.table.clone()));
734            // Apply filtering based on visible rows
735            // Note: This is a simplified approach - in production we'd need proper filtering
736            view
737        } else {
738            DataView::new(Arc::new(self.table.clone()))
739        };
740
741        // Create the WindowContext with the full spec (including frame)
742        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
743
744        let context = Arc::new(context);
745        self.window_contexts.insert(key, Arc::clone(&context));
746        Ok(context)
747    }
748
749    /// Evaluate a window function
750    fn evaluate_window_function(
751        &mut self,
752        name: &str,
753        args: &[SqlExpression],
754        spec: &WindowSpec,
755        row_index: usize,
756    ) -> Result<DataValue> {
757        let name_upper = name.to_uppercase();
758
759        // First check if this is a syntactic sugar function in the registry
760        debug!("Looking for window function {} in registry", name_upper);
761        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
762            debug!("Found window function {} in registry", name_upper);
763
764            // Dereference to get the actual window function
765            let window_fn = window_fn_arc.as_ref();
766
767            // Validate arguments
768            window_fn.validate_args(args)?;
769
770            // Transform the window spec based on the function's requirements
771            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
772
773            // Get or create the window context with the transformed spec
774            let context = self.get_or_create_window_context(&transformed_spec)?;
775
776            // Create an expression evaluator adapter
777            struct EvaluatorAdapter<'a, 'b> {
778                evaluator: &'a mut ArithmeticEvaluator<'b>,
779                row_index: usize,
780            }
781
782            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
783                fn evaluate(
784                    &mut self,
785                    expr: &SqlExpression,
786                    _row_index: usize,
787                ) -> Result<DataValue> {
788                    self.evaluator.evaluate(expr, self.row_index)
789                }
790            }
791
792            let mut adapter = EvaluatorAdapter {
793                evaluator: self,
794                row_index,
795            };
796
797            // Call the window function's compute method
798            return window_fn.compute(&context, row_index, args, &mut adapter);
799        }
800
801        // Fall back to built-in window functions
802        let context = self.get_or_create_window_context(spec)?;
803
804        match name_upper.as_str() {
805            "LAG" => {
806                // LAG(column, offset, default)
807                if args.is_empty() {
808                    return Err(anyhow!("LAG requires at least 1 argument"));
809                }
810
811                // Get column name
812                let column = match &args[0] {
813                    SqlExpression::Column(col) => col.clone(),
814                    _ => return Err(anyhow!("LAG first argument must be a column")),
815                };
816
817                // Get offset (default 1)
818                let offset = if args.len() > 1 {
819                    match self.evaluate(&args[1], row_index)? {
820                        DataValue::Integer(i) => i as i32,
821                        _ => return Err(anyhow!("LAG offset must be an integer")),
822                    }
823                } else {
824                    1
825                };
826
827                // Get value at offset
828                Ok(context
829                    .get_offset_value(row_index, -offset, &column.name)
830                    .unwrap_or(DataValue::Null))
831            }
832            "LEAD" => {
833                // LEAD(column, offset, default)
834                if args.is_empty() {
835                    return Err(anyhow!("LEAD requires at least 1 argument"));
836                }
837
838                // Get column name
839                let column = match &args[0] {
840                    SqlExpression::Column(col) => col.clone(),
841                    _ => return Err(anyhow!("LEAD first argument must be a column")),
842                };
843
844                // Get offset (default 1)
845                let offset = if args.len() > 1 {
846                    match self.evaluate(&args[1], row_index)? {
847                        DataValue::Integer(i) => i as i32,
848                        _ => return Err(anyhow!("LEAD offset must be an integer")),
849                    }
850                } else {
851                    1
852                };
853
854                // Get value at offset
855                Ok(context
856                    .get_offset_value(row_index, offset, &column.name)
857                    .unwrap_or(DataValue::Null))
858            }
859            "ROW_NUMBER" => {
860                // ROW_NUMBER() - no arguments
861                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
862            }
863            "FIRST_VALUE" => {
864                // FIRST_VALUE(column) OVER (... ROWS ...)
865                if args.is_empty() {
866                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
867                }
868
869                let column = match &args[0] {
870                    SqlExpression::Column(col) => col.clone(),
871                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
872                };
873
874                // Use frame-aware version if frame is specified
875                if context.has_frame() {
876                    Ok(context
877                        .get_frame_first_value(row_index, &column.name)
878                        .unwrap_or(DataValue::Null))
879                } else {
880                    Ok(context
881                        .get_first_value(row_index, &column.name)
882                        .unwrap_or(DataValue::Null))
883                }
884            }
885            "LAST_VALUE" => {
886                // LAST_VALUE(column) OVER (... ROWS ...)
887                if args.is_empty() {
888                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
889                }
890
891                let column = match &args[0] {
892                    SqlExpression::Column(col) => col.clone(),
893                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
894                };
895
896                // Use frame-aware version if frame is specified
897                if context.has_frame() {
898                    Ok(context
899                        .get_frame_last_value(row_index, &column.name)
900                        .unwrap_or(DataValue::Null))
901                } else {
902                    Ok(context
903                        .get_last_value(row_index, &column.name)
904                        .unwrap_or(DataValue::Null))
905                }
906            }
907            "SUM" => {
908                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
909                if args.is_empty() {
910                    return Err(anyhow!("SUM requires 1 argument"));
911                }
912
913                let column = match &args[0] {
914                    SqlExpression::Column(col) => col.clone(),
915                    _ => return Err(anyhow!("SUM argument must be a column")),
916                };
917
918                // Use frame-aware sum if frame is specified, otherwise use partition sum
919                if context.has_frame() {
920                    Ok(context
921                        .get_frame_sum(row_index, &column.name)
922                        .unwrap_or(DataValue::Null))
923                } else {
924                    Ok(context
925                        .get_partition_sum(row_index, &column.name)
926                        .unwrap_or(DataValue::Null))
927                }
928            }
929            "AVG" => {
930                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
931                if args.is_empty() {
932                    return Err(anyhow!("AVG requires 1 argument"));
933                }
934
935                let column = match &args[0] {
936                    SqlExpression::Column(col) => col.clone(),
937                    _ => return Err(anyhow!("AVG argument must be a column")),
938                };
939
940                Ok(context
941                    .get_frame_avg(row_index, &column.name)
942                    .unwrap_or(DataValue::Null))
943            }
944            "STDDEV" | "STDEV" => {
945                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
946                if args.is_empty() {
947                    return Err(anyhow!("STDDEV requires 1 argument"));
948                }
949
950                let column = match &args[0] {
951                    SqlExpression::Column(col) => col.clone(),
952                    _ => return Err(anyhow!("STDDEV argument must be a column")),
953                };
954
955                Ok(context
956                    .get_frame_stddev(row_index, &column.name)
957                    .unwrap_or(DataValue::Null))
958            }
959            "VARIANCE" | "VAR" => {
960                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
961                if args.is_empty() {
962                    return Err(anyhow!("VARIANCE requires 1 argument"));
963                }
964
965                let column = match &args[0] {
966                    SqlExpression::Column(col) => col.clone(),
967                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
968                };
969
970                Ok(context
971                    .get_frame_variance(row_index, &column.name)
972                    .unwrap_or(DataValue::Null))
973            }
974            "MIN" => {
975                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
976                if args.is_empty() {
977                    return Err(anyhow!("MIN requires 1 argument"));
978                }
979
980                let column = match &args[0] {
981                    SqlExpression::Column(col) => col.clone(),
982                    _ => return Err(anyhow!("MIN argument must be a column")),
983                };
984
985                let frame_rows = context.get_frame_rows(row_index);
986                if frame_rows.is_empty() {
987                    return Ok(DataValue::Null);
988                }
989
990                let source_table = context.source();
991                let col_idx = source_table
992                    .get_column_index(&column.name)
993                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
994
995                let mut min_value: Option<DataValue> = None;
996                for &row_idx in &frame_rows {
997                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
998                        if !matches!(value, DataValue::Null) {
999                            match &min_value {
1000                                None => min_value = Some(value.clone()),
1001                                Some(current_min) => {
1002                                    if value < current_min {
1003                                        min_value = Some(value.clone());
1004                                    }
1005                                }
1006                            }
1007                        }
1008                    }
1009                }
1010
1011                Ok(min_value.unwrap_or(DataValue::Null))
1012            }
1013            "MAX" => {
1014                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1015                if args.is_empty() {
1016                    return Err(anyhow!("MAX requires 1 argument"));
1017                }
1018
1019                let column = match &args[0] {
1020                    SqlExpression::Column(col) => col.clone(),
1021                    _ => return Err(anyhow!("MAX argument must be a column")),
1022                };
1023
1024                let frame_rows = context.get_frame_rows(row_index);
1025                if frame_rows.is_empty() {
1026                    return Ok(DataValue::Null);
1027                }
1028
1029                let source_table = context.source();
1030                let col_idx = source_table
1031                    .get_column_index(&column.name)
1032                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1033
1034                let mut max_value: Option<DataValue> = None;
1035                for &row_idx in &frame_rows {
1036                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1037                        if !matches!(value, DataValue::Null) {
1038                            match &max_value {
1039                                None => max_value = Some(value.clone()),
1040                                Some(current_max) => {
1041                                    if value > current_max {
1042                                        max_value = Some(value.clone());
1043                                    }
1044                                }
1045                            }
1046                        }
1047                    }
1048                }
1049
1050                Ok(max_value.unwrap_or(DataValue::Null))
1051            }
1052            "COUNT" => {
1053                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1054                // Use frame-aware count if frame is specified, otherwise use partition count
1055
1056                if args.is_empty() {
1057                    // COUNT(*) OVER (...)
1058                    if context.has_frame() {
1059                        Ok(context
1060                            .get_frame_count(row_index, None)
1061                            .unwrap_or(DataValue::Null))
1062                    } else {
1063                        Ok(context
1064                            .get_partition_count(row_index, None)
1065                            .unwrap_or(DataValue::Null))
1066                    }
1067                } else {
1068                    // Check for COUNT(*)
1069                    let column = match &args[0] {
1070                        SqlExpression::Column(col) => {
1071                            if col.name == "*" {
1072                                // COUNT(*) - count all rows
1073                                if context.has_frame() {
1074                                    return Ok(context
1075                                        .get_frame_count(row_index, None)
1076                                        .unwrap_or(DataValue::Null));
1077                                } else {
1078                                    return Ok(context
1079                                        .get_partition_count(row_index, None)
1080                                        .unwrap_or(DataValue::Null));
1081                                }
1082                            }
1083                            col.clone()
1084                        }
1085                        SqlExpression::StringLiteral(s) if s == "*" => {
1086                            // COUNT(*) as StringLiteral
1087                            if context.has_frame() {
1088                                return Ok(context
1089                                    .get_frame_count(row_index, None)
1090                                    .unwrap_or(DataValue::Null));
1091                            } else {
1092                                return Ok(context
1093                                    .get_partition_count(row_index, None)
1094                                    .unwrap_or(DataValue::Null));
1095                            }
1096                        }
1097                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1098                    };
1099
1100                    // COUNT(column) - count non-null values
1101                    if context.has_frame() {
1102                        Ok(context
1103                            .get_frame_count(row_index, Some(&column.name))
1104                            .unwrap_or(DataValue::Null))
1105                    } else {
1106                        Ok(context
1107                            .get_partition_count(row_index, Some(&column.name))
1108                            .unwrap_or(DataValue::Null))
1109                    }
1110                }
1111            }
1112            _ => Err(anyhow!("Unknown window function: {}", name)),
1113        }
1114    }
1115
1116    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1117    fn evaluate_method_call(
1118        &mut self,
1119        object: &str,
1120        method: &str,
1121        args: &[SqlExpression],
1122        row_index: usize,
1123    ) -> Result<DataValue> {
1124        // Get column value
1125        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1126            let suggestion = self.find_similar_column(object);
1127            match suggestion {
1128                Some(similar) => {
1129                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1130                }
1131                None => anyhow!("Column '{}' not found", object),
1132            }
1133        })?;
1134
1135        let cell_value = self.table.get_value(row_index, col_index).cloned();
1136
1137        self.evaluate_method_on_value(
1138            &cell_value.unwrap_or(DataValue::Null),
1139            method,
1140            args,
1141            row_index,
1142        )
1143    }
1144
1145    /// Evaluate a method on a value
1146    fn evaluate_method_on_value(
1147        &mut self,
1148        value: &DataValue,
1149        method: &str,
1150        args: &[SqlExpression],
1151        row_index: usize,
1152    ) -> Result<DataValue> {
1153        // First, try to proxy the method through the function registry
1154        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1155
1156        // Map method names to function names (case-insensitive matching)
1157        let function_name = match method.to_lowercase().as_str() {
1158            "trim" => "TRIM",
1159            "trimstart" | "trimbegin" => "TRIMSTART",
1160            "trimend" => "TRIMEND",
1161            "length" | "len" => "LENGTH",
1162            "contains" => "CONTAINS",
1163            "startswith" => "STARTSWITH",
1164            "endswith" => "ENDSWITH",
1165            "indexof" => "INDEXOF",
1166            _ => method, // Try the method name as-is
1167        };
1168
1169        // Check if we have this function in the registry
1170        if self.function_registry.get(function_name).is_some() {
1171            debug!(
1172                "Proxying method '{}' through function registry as '{}'",
1173                method, function_name
1174            );
1175
1176            // Prepare arguments: receiver is the first argument, followed by method args
1177            let mut func_args = vec![value.clone()];
1178
1179            // Evaluate method arguments and add them
1180            for arg in args {
1181                func_args.push(self.evaluate(arg, row_index)?);
1182            }
1183
1184            // Get the function and call it
1185            let func = self.function_registry.get(function_name).unwrap();
1186            return func.evaluate(&func_args);
1187        }
1188
1189        // If not in registry, the method is not supported
1190        // All methods should be registered in the function registry
1191        Err(anyhow!(
1192            "Method '{}' not found. It should be registered in the function registry.",
1193            method
1194        ))
1195    }
1196
1197    /// Evaluate a CASE expression
1198    fn evaluate_case_expression(
1199        &mut self,
1200        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1201        else_branch: &Option<Box<SqlExpression>>,
1202        row_index: usize,
1203    ) -> Result<DataValue> {
1204        debug!(
1205            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1206            row_index
1207        );
1208
1209        // Evaluate each WHEN condition in order
1210        for branch in when_branches {
1211            // Evaluate the condition as a boolean
1212            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1213
1214            if condition_result {
1215                debug!("CASE: WHEN condition matched, evaluating result expression");
1216                return self.evaluate(&branch.result, row_index);
1217            }
1218        }
1219
1220        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1221        if let Some(else_expr) = else_branch {
1222            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1223            self.evaluate(else_expr, row_index)
1224        } else {
1225            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1226            Ok(DataValue::Null)
1227        }
1228    }
1229
1230    /// Evaluate a simple CASE expression
1231    fn evaluate_simple_case_expression(
1232        &mut self,
1233        expr: &Box<SqlExpression>,
1234        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1235        else_branch: &Option<Box<SqlExpression>>,
1236        row_index: usize,
1237    ) -> Result<DataValue> {
1238        debug!(
1239            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1240            row_index
1241        );
1242
1243        // Evaluate the main expression once
1244        let case_value = self.evaluate(expr, row_index)?;
1245        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1246
1247        // Compare against each WHEN value in order
1248        for branch in when_branches {
1249            // Evaluate the WHEN value
1250            let when_value = self.evaluate(&branch.value, row_index)?;
1251
1252            // Check for equality
1253            if self.values_equal(&case_value, &when_value)? {
1254                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1255                return self.evaluate(&branch.result, row_index);
1256            }
1257        }
1258
1259        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1260        if let Some(else_expr) = else_branch {
1261            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1262            self.evaluate(else_expr, row_index)
1263        } else {
1264            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1265            Ok(DataValue::Null)
1266        }
1267    }
1268
1269    /// Check if two DataValues are equal
1270    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1271        match (left, right) {
1272            (DataValue::Null, DataValue::Null) => Ok(true),
1273            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1274            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1275            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1276            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1277            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1278            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1279            // Type coercion for numeric comparisons
1280            (DataValue::Integer(a), DataValue::Float(b)) => {
1281                Ok((*a as f64 - b).abs() < f64::EPSILON)
1282            }
1283            (DataValue::Float(a), DataValue::Integer(b)) => {
1284                Ok((a - *b as f64).abs() < f64::EPSILON)
1285            }
1286            _ => Ok(false),
1287        }
1288    }
1289
1290    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1291    fn evaluate_condition_as_bool(
1292        &mut self,
1293        expr: &SqlExpression,
1294        row_index: usize,
1295    ) -> Result<bool> {
1296        let value = self.evaluate(expr, row_index)?;
1297
1298        match value {
1299            DataValue::Boolean(b) => Ok(b),
1300            DataValue::Integer(i) => Ok(i != 0),
1301            DataValue::Float(f) => Ok(f != 0.0),
1302            DataValue::Null => Ok(false),
1303            DataValue::String(s) => Ok(!s.is_empty()),
1304            DataValue::InternedString(s) => Ok(!s.is_empty()),
1305            _ => Ok(true), // Other types are considered truthy
1306        }
1307    }
1308}
1309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table `test(a, b, c)` holding `(10, 2.5, 4)`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        let row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(row).unwrap();

        table
    }

    /// Unquoted column reference expression for `name`.
    fn col_expr(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let value = evaluator.evaluate(&col_expr("a"), 0).unwrap();
        assert_eq!(value, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let literal = |text: &str| SqlExpression::NumberLiteral(text.to_string());

        let whole = evaluator.evaluate(&literal("42"), 0).unwrap();
        assert_eq!(whole, DataValue::Integer(42));

        let fractional = evaluator.evaluate(&literal("3.14"), 0).unwrap();
        assert_eq!(fractional, DataValue::Float(3.14));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer stays integral
        let int_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        // Integer + Float widens to Float
        let mixed_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float widens to Float
        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division stays integral
        let exact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        // Inexact division produces a Float
        let inexact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(0))
            .expect_err("dividing by zero must fail");
        assert!(err.to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let product = SqlExpression::BinaryOp {
            left: Box::new(col_expr("a")),
            op: "*".to_string(),
            right: Box::new(col_expr("b")),
        };

        let value = evaluator.evaluate(&product, 0).unwrap();
        assert_eq!(value, DataValue::Float(25.0));
    }
}