sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::WindowSpec;
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
/// This is different from `RecursiveWhereEvaluator` which returns boolean
///
/// An evaluator borrows the table it reads from, so it cannot outlive that
/// table. Evaluation takes `&mut self` because window contexts are cached
/// on the evaluator as they are created.
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows are read when evaluating column references.
    table: &'a DataTable,
    /// Date notation setting; taken from the global configuration unless the
    /// caller supplies one through a constructor.
    date_notation: String,
    /// Registry consulted for non-aggregate (per-row) function calls.
    function_registry: Arc<FunctionRegistry>,
    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
    /// Registry consulted first when evaluating window functions.
    window_function_registry: Arc<WindowFunctionRegistry>,
    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
}
30
31impl<'a> ArithmeticEvaluator<'a> {
32    #[must_use]
33    pub fn new(table: &'a DataTable) -> Self {
34        Self {
35            table,
36            date_notation: get_date_notation(),
37            function_registry: Arc::new(FunctionRegistry::new()),
38            aggregate_registry: Arc::new(AggregateRegistry::new()),
39            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41            visible_rows: None,
42            window_contexts: HashMap::new(),
43            table_aliases: HashMap::new(),
44        }
45    }
46
47    #[must_use]
48    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49        Self {
50            table,
51            date_notation,
52            function_registry: Arc::new(FunctionRegistry::new()),
53            aggregate_registry: Arc::new(AggregateRegistry::new()),
54            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56            visible_rows: None,
57            window_contexts: HashMap::new(),
58            table_aliases: HashMap::new(),
59        }
60    }
61
62    /// Set visible rows for aggregate functions (for filtered views)
63    #[must_use]
64    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65        self.visible_rows = Some(rows);
66        self
67    }
68
69    /// Set table aliases for qualified column resolution
70    #[must_use]
71    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72        self.table_aliases = aliases;
73        self
74    }
75
76    #[must_use]
77    pub fn with_date_notation_and_registry(
78        table: &'a DataTable,
79        date_notation: String,
80        function_registry: Arc<FunctionRegistry>,
81    ) -> Self {
82        Self {
83            table,
84            date_notation,
85            function_registry,
86            aggregate_registry: Arc::new(AggregateRegistry::new()),
87            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89            visible_rows: None,
90            window_contexts: HashMap::new(),
91            table_aliases: HashMap::new(),
92        }
93    }
94
95    /// Find a column name similar to the given name using edit distance
96    fn find_similar_column(&self, name: &str) -> Option<String> {
97        let columns = self.table.column_names();
98        let mut best_match: Option<(String, usize)> = None;
99
100        for col in columns {
101            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102            // Only suggest if distance is small (likely a typo)
103            // Allow up to 3 edits for longer names
104            let max_distance = if name.len() > 10 { 3 } else { 2 };
105            if distance <= max_distance {
106                match &best_match {
107                    None => best_match = Some((col, distance)),
108                    Some((_, best_dist)) if distance < *best_dist => {
109                        best_match = Some((col, distance));
110                    }
111                    _ => {}
112                }
113            }
114        }
115
116        best_match.map(|(name, _)| name)
117    }
118
119    /// Calculate Levenshtein edit distance between two strings
120    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121        // Use the shared implementation from string_methods
122        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123    }
124
    /// Evaluate an SQL expression to produce a `DataValue`
    ///
    /// Dispatches on the expression variant to the matching helper, using
    /// `row_index` as the row against which column references are resolved.
    ///
    /// # Errors
    ///
    /// Returns an error for expression variants not listed in the `match`
    /// below, and propagates any error raised by the helper that handles the
    /// expression.
    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
        debug!(
            "ArithmeticEvaluator: evaluating {:?} for row {}",
            expr, row_index
        );

        match expr {
            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
            SqlExpression::Null => Ok(DataValue::Null),
            SqlExpression::BinaryOp { left, op, right } => {
                self.evaluate_binary_op(left, op, right, row_index)
            }
            SqlExpression::FunctionCall {
                name,
                args,
                distinct,
            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
            SqlExpression::WindowFunction {
                name,
                args,
                window_spec,
            } => self.evaluate_window_function(name, args, window_spec, row_index),
            SqlExpression::MethodCall {
                object,
                method,
                args,
            } => self.evaluate_method_call(object, method, args, row_index),
            SqlExpression::ChainedMethodCall { base, method, args } => {
                // Evaluate the base expression first, then apply the method
                let base_value = self.evaluate(base, row_index)?;
                self.evaluate_method_on_value(&base_value, method, args, row_index)
            }
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
            // Note: the `expr` field bound here shadows the outer `expr` parameter.
            SqlExpression::SimpleCaseExpression {
                expr,
                when_branches,
                else_branch,
            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
            // Every other variant is rejected rather than silently ignored.
            _ => Err(anyhow!(
                "Unsupported expression type for arithmetic evaluation: {:?}",
                expr
            )),
        }
    }
176
177    /// Evaluate a column reference
178    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
179        // First try to resolve qualified column names (table.column or alias.column)
180        let resolved_column = if column_name.contains('.') {
181            // Split on last dot to handle cases like "schema.table.column"
182            if let Some(dot_pos) = column_name.rfind('.') {
183                let _table_or_alias = &column_name[..dot_pos];
184                let col_name = &column_name[dot_pos + 1..];
185
186                // For now, just use the column name part
187                // In the future, we could validate the table/alias part
188                debug!(
189                    "Resolving qualified column: {} -> {}",
190                    column_name, col_name
191                );
192                col_name.to_string()
193            } else {
194                column_name.to_string()
195            }
196        } else {
197            column_name.to_string()
198        };
199
200        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
201            idx
202        } else if resolved_column != column_name {
203            // If not found, try the original name
204            if let Some(idx) = self.table.get_column_index(column_name) {
205                idx
206            } else {
207                let suggestion = self.find_similar_column(&resolved_column);
208                return Err(match suggestion {
209                    Some(similar) => anyhow!(
210                        "Column '{}' not found. Did you mean '{}'?",
211                        column_name,
212                        similar
213                    ),
214                    None => anyhow!("Column '{}' not found", column_name),
215                });
216            }
217        } else {
218            let suggestion = self.find_similar_column(&resolved_column);
219            return Err(match suggestion {
220                Some(similar) => anyhow!(
221                    "Column '{}' not found. Did you mean '{}'?",
222                    column_name,
223                    similar
224                ),
225                None => anyhow!("Column '{}' not found", column_name),
226            });
227        };
228
229        if row_index >= self.table.row_count() {
230            return Err(anyhow!("Row index {} out of bounds", row_index));
231        }
232
233        let row = self
234            .table
235            .get_row(row_index)
236            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
237
238        let value = row
239            .get(col_index)
240            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
241
242        Ok(value.clone())
243    }
244
245    /// Evaluate a number literal (handles both integers and floats)
246    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
247        // Try to parse as integer first
248        if let Ok(int_val) = number_str.parse::<i64>() {
249            return Ok(DataValue::Integer(int_val));
250        }
251
252        // If that fails, try as float
253        if let Ok(float_val) = number_str.parse::<f64>() {
254            return Ok(DataValue::Float(float_val));
255        }
256
257        Err(anyhow!("Invalid number literal: {}", number_str))
258    }
259
260    /// Evaluate a binary operation (arithmetic)
261    fn evaluate_binary_op(
262        &mut self,
263        left: &SqlExpression,
264        op: &str,
265        right: &SqlExpression,
266        row_index: usize,
267    ) -> Result<DataValue> {
268        let left_val = self.evaluate(left, row_index)?;
269        let right_val = self.evaluate(right, row_index)?;
270
271        debug!(
272            "ArithmeticEvaluator: {} {} {}",
273            self.format_value(&left_val),
274            op,
275            self.format_value(&right_val)
276        );
277
278        match op {
279            "+" => self.add_values(&left_val, &right_val),
280            "-" => self.subtract_values(&left_val, &right_val),
281            "*" => self.multiply_values(&left_val, &right_val),
282            "/" => self.divide_values(&left_val, &right_val),
283            "%" => {
284                // Modulo operator - call MOD function
285                let args = vec![left.clone(), right.clone()];
286                self.evaluate_function("MOD", &args, row_index)
287            }
288            // Comparison operators (return boolean results)
289            // Use centralized comparison logic for consistency
290            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
291                let result = compare_with_op(&left_val, &right_val, op, false);
292                Ok(DataValue::Boolean(result))
293            }
294            // IS NULL / IS NOT NULL operators
295            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
296            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
297            // Logical operators
298            "AND" => {
299                let left_bool = self.to_bool(&left_val)?;
300                let right_bool = self.to_bool(&right_val)?;
301                Ok(DataValue::Boolean(left_bool && right_bool))
302            }
303            "OR" => {
304                let left_bool = self.to_bool(&left_val)?;
305                let right_bool = self.to_bool(&right_val)?;
306                Ok(DataValue::Boolean(left_bool || right_bool))
307            }
308            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
309        }
310    }
311
312    /// Add two `DataValues` with type coercion
313    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314        // NULL handling - any operation with NULL returns NULL
315        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316            return Ok(DataValue::Null);
317        }
318
319        match (left, right) {
320            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
321            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
322            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
323            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
324            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
325        }
326    }
327
328    /// Subtract two `DataValues` with type coercion
329    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330        // NULL handling - any operation with NULL returns NULL
331        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332            return Ok(DataValue::Null);
333        }
334
335        match (left, right) {
336            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
337            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
338            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
339            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
340            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
341        }
342    }
343
344    /// Multiply two `DataValues` with type coercion
345    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346        // NULL handling - any operation with NULL returns NULL
347        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348            return Ok(DataValue::Null);
349        }
350
351        match (left, right) {
352            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
353            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
354            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
355            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
356            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
357        }
358    }
359
360    /// Divide two `DataValues` with type coercion
361    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
362        // NULL handling - any operation with NULL returns NULL
363        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
364            return Ok(DataValue::Null);
365        }
366
367        // Check for division by zero first
368        let is_zero = match right {
369            DataValue::Integer(0) => true,
370            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
371            _ => false,
372        };
373
374        if is_zero {
375            return Err(anyhow!("Division by zero"));
376        }
377
378        match (left, right) {
379            (DataValue::Integer(a), DataValue::Integer(b)) => {
380                // Integer division - if result is exact, keep as int, otherwise promote to float
381                if a % b == 0 {
382                    Ok(DataValue::Integer(a / b))
383                } else {
384                    Ok(DataValue::Float(*a as f64 / *b as f64))
385                }
386            }
387            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
388            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
389            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
390            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
391        }
392    }
393
394    /// Format a `DataValue` for debug output
395    fn format_value(&self, value: &DataValue) -> String {
396        match value {
397            DataValue::Integer(i) => i.to_string(),
398            DataValue::Float(f) => f.to_string(),
399            DataValue::String(s) => format!("'{s}'"),
400            _ => format!("{value:?}"),
401        }
402    }
403
404    /// Convert a DataValue to boolean for logical operations
405    fn to_bool(&self, value: &DataValue) -> Result<bool> {
406        match value {
407            DataValue::Boolean(b) => Ok(*b),
408            DataValue::Integer(i) => Ok(*i != 0),
409            DataValue::Float(f) => Ok(*f != 0.0),
410            DataValue::Null => Ok(false),
411            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
412        }
413    }
414
415    /// Evaluate a function call
416    fn evaluate_function_with_distinct(
417        &mut self,
418        name: &str,
419        args: &[SqlExpression],
420        distinct: bool,
421        row_index: usize,
422    ) -> Result<DataValue> {
423        // If DISTINCT is specified, handle it specially for aggregate functions
424        if distinct {
425            let name_upper = name.to_uppercase();
426
427            // Check if it's an aggregate function in either registry
428            if self.aggregate_registry.is_aggregate(&name_upper)
429                || self.new_aggregate_registry.contains(&name_upper)
430            {
431                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
432            } else {
433                return Err(anyhow!(
434                    "DISTINCT can only be used with aggregate functions"
435                ));
436            }
437        }
438
439        // Otherwise, use the regular evaluation
440        self.evaluate_function(name, args, row_index)
441    }
442
    /// Evaluates an aggregate function with `DISTINCT` semantics.
    ///
    /// The rows aggregated over are `visible_rows` when set, otherwise every
    /// row of the table. Only the first argument is evaluated per row (plus
    /// the separator argument for `STRING_AGG`); duplicate values are dropped
    /// before they are accumulated. The new aggregate registry is consulted
    /// first, then the old one.
    ///
    /// `_row_index` is unused: aggregates range over all selected rows.
    ///
    /// # Errors
    ///
    /// Returns an error when `name` is in neither aggregate registry, and
    /// propagates errors from evaluating arguments or accumulating values.
    fn evaluate_aggregate_with_distinct(
        &mut self,
        name: &str,
        args: &[SqlExpression],
        _row_index: usize,
    ) -> Result<DataValue> {
        // Registries are looked up by upper-case name.
        let name_upper = name.to_uppercase();

        // Check new aggregate registry first for migrated functions
        if self.new_aggregate_registry.get(&name_upper).is_some() {
            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
                visible.clone()
            } else {
                (0..self.table.rows.len()).collect()
            };

            // Collect and deduplicate values for DISTINCT
            // With no arguments nothing is collected, so the state is
            // finalized without any accumulated value.
            let mut vals = Vec::new();
            for &row_idx in &rows_to_process {
                if !args.is_empty() {
                    let value = self.evaluate(&args[0], row_idx)?;
                    vals.push(value);
                }
            }

            // Deduplicate values
            // Uniqueness here is decided by the value's `Debug` rendering;
            // the first occurrence of each rendering is kept, in row order.
            let mut seen = HashSet::new();
            let unique_values: Vec<_> = vals
                .into_iter()
                .filter(|v| {
                    let key = format!("{:?}", v);
                    seen.insert(key)
                })
                .collect();

            // Get the aggregate function from the new registry
            // (the lookup is repeated after evaluation; `unwrap` relies on the
            // `is_some()` check at the top of this branch)
            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
            let mut state = agg_func.create_state();

            // Use unique values
            for value in &unique_values {
                state.accumulate(value)?;
            }

            return Ok(state.finalize());
        }

        // Check old aggregate registry (DISTINCT handling)
        if self.aggregate_registry.get(&name_upper).is_some() {
            // Determine which rows to process first
            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
                visible.clone()
            } else {
                (0..self.table.rows.len()).collect()
            };

            // Special handling for STRING_AGG with separator parameter
            if name_upper == "STRING_AGG" && args.len() >= 2 {
                // STRING_AGG(DISTINCT column, separator)
                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
                    // Evaluate the separator (second argument) once
                    // NOTE(review): this inner check repeats the outer
                    // `args.len() >= 2` condition, so the `else` arm below
                    // cannot be reached from here.
                    if args.len() >= 2 {
                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
                        match separator {
                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
                            DataValue::InternedString(s) => {
                                crate::sql::aggregates::StringAggState::new(&s)
                            }
                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
                        }
                    } else {
                        crate::sql::aggregates::StringAggState::new(",")
                    },
                );

                // Evaluate the first argument (column) for each row and accumulate
                // Handle DISTINCT - use a HashSet to track seen values
                let mut seen_values = HashSet::new();

                for &row_idx in &rows_to_process {
                    let value = self.evaluate(&args[0], row_idx)?;

                    // Skip if we've seen this value
                    if !seen_values.insert(value.clone()) {
                        continue; // Skip duplicate values
                    }

                    // Now get the aggregate function and accumulate
                    // (looked up only after `self.evaluate` above has returned)
                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                    agg_func.accumulate(&mut state, &value)?;
                }

                // Finalize the aggregate
                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                return Ok(agg_func.finalize(state));
            }

            // For other aggregates with DISTINCT
            // Evaluate the argument expression for each row
            let mut vals = Vec::new();
            for &row_idx in &rows_to_process {
                if !args.is_empty() {
                    let value = self.evaluate(&args[0], row_idx)?;
                    vals.push(value);
                }
            }

            // Deduplicate values for DISTINCT
            // Unlike the new-registry branch, uniqueness here uses the
            // value's own hash/equality; first occurrences are kept in order.
            let mut seen = HashSet::new();
            let mut unique_values = Vec::new();
            for value in vals {
                if seen.insert(value.clone()) {
                    unique_values.push(value);
                }
            }

            // Now get the aggregate function and process
            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
            let mut state = agg_func.init();

            // Use unique values
            for value in &unique_values {
                agg_func.accumulate(&mut state, value)?;
            }

            return Ok(agg_func.finalize(state));
        }

        Err(anyhow!("Unknown aggregate function: {}", name))
    }
573
    /// Evaluates a function call without `DISTINCT`.
    ///
    /// Lookup order:
    /// 1. the new aggregate registry (by upper-cased name),
    /// 2. the old aggregate registry (by upper-cased name),
    /// 3. the scalar function registry (by `name` exactly as given).
    ///
    /// Aggregates range over `visible_rows` when set, otherwise over every
    /// row of the table, and ignore `row_index`. Only their first argument is
    /// evaluated per row (plus the separator for `STRING_AGG`). Scalar
    /// functions evaluate all arguments against `row_index`.
    ///
    /// # Errors
    ///
    /// Returns an error when the function is in none of the registries, and
    /// propagates errors from argument evaluation, accumulation, or the
    /// function itself.
    fn evaluate_function(
        &mut self,
        name: &str,
        args: &[SqlExpression],
        row_index: usize,
    ) -> Result<DataValue> {
        // Check if this is an aggregate function
        let name_upper = name.to_uppercase();

        // Check new aggregate registry first (for migrated functions)
        if self.new_aggregate_registry.get(&name_upper).is_some() {
            // Use the new registry for any aggregate it provides
            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
                visible.clone()
            } else {
                (0..self.table.rows.len()).collect()
            };

            // Get the aggregate function from the new registry
            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
            let mut state = agg_func.create_state();

            // Special handling for COUNT(*)
            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
                // `*` may arrive as a column reference or a string literal;
                // an empty argument list is treated the same way.
                if args.is_empty()
                    || (args.len() == 1
                        && matches!(&args[0], SqlExpression::Column(col) if col == "*"))
                    || (args.len() == 1
                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
                {
                    // COUNT(*) or COUNT_STAR - count all rows
                    for _ in &rows_to_process {
                        state.accumulate(&DataValue::Integer(1))?;
                    }
                } else {
                    // COUNT(column) - count non-null values
                    // (every evaluated value is passed on; NULL handling is
                    // presumably done by the aggregate state - TODO confirm)
                    for &row_idx in &rows_to_process {
                        let value = self.evaluate(&args[0], row_idx)?;
                        state.accumulate(&value)?;
                    }
                }
            } else {
                // Other aggregates - evaluate arguments and accumulate
                // With no arguments the state is finalized untouched.
                if !args.is_empty() {
                    for &row_idx in &rows_to_process {
                        let value = self.evaluate(&args[0], row_idx)?;
                        state.accumulate(&value)?;
                    }
                }
            }

            return Ok(state.finalize());
        }

        // Check old aggregate registry (for non-migrated functions)
        if self.aggregate_registry.get(&name_upper).is_some() {
            // Determine which rows to process first
            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
                visible.clone()
            } else {
                (0..self.table.rows.len()).collect()
            };

            // Special handling for STRING_AGG with separator parameter
            if name_upper == "STRING_AGG" && args.len() >= 2 {
                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
                    // Evaluate the separator (second argument) once
                    // NOTE(review): this inner check repeats the outer
                    // `args.len() >= 2` condition, so the `else` arm below
                    // cannot be reached from here.
                    if args.len() >= 2 {
                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
                        match separator {
                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
                            DataValue::InternedString(s) => {
                                crate::sql::aggregates::StringAggState::new(&s)
                            }
                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
                        }
                    } else {
                        crate::sql::aggregates::StringAggState::new(",")
                    },
                );

                // Evaluate the first argument (column) for each row and accumulate
                for &row_idx in &rows_to_process {
                    let value = self.evaluate(&args[0], row_idx)?;
                    // Now get the aggregate function and accumulate
                    // (looked up only after `self.evaluate` above has returned)
                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                    agg_func.accumulate(&mut state, &value)?;
                }

                // Finalize the aggregate
                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                return Ok(agg_func.finalize(state));
            }

            // Evaluate arguments first if needed (to avoid borrow issues)
            // `None` means "no per-row values": either no arguments at all or
            // a single `*` column reference.
            let values = if !args.is_empty()
                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
            {
                // Evaluate the argument expression for each row
                let mut vals = Vec::new();
                for &row_idx in &rows_to_process {
                    let value = self.evaluate(&args[0], row_idx)?;
                    vals.push(value);
                }
                Some(vals)
            } else {
                None
            };

            // Now get the aggregate function and process
            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
            let mut state = agg_func.init();

            if let Some(values) = values {
                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
                for value in &values {
                    agg_func.accumulate(&mut state, value)?;
                }
            } else {
                // COUNT(*) case
                // One placeholder value is accumulated per selected row.
                for _ in &rows_to_process {
                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
                }
            }

            return Ok(agg_func.finalize(state));
        }

        // First check if this function exists in the registry
        // Note: this lookup uses `name` as given, not `name_upper`.
        if self.function_registry.get(name).is_some() {
            // Evaluate all arguments first to avoid borrow issues
            let mut evaluated_args = Vec::new();
            for arg in args {
                evaluated_args.push(self.evaluate(arg, row_index)?);
            }

            // Get the function and call it
            let func = self.function_registry.get(name).unwrap();
            return func.evaluate(&evaluated_args);
        }

        // If not in registry, return error for unknown function
        Err(anyhow!("Unknown function: {}", name))
    }
719
720    /// Get or create a WindowContext for the given specification
721    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
722        // Create a key for caching based on the spec
723        let key = format!("{:?}", spec);
724
725        if let Some(context) = self.window_contexts.get(&key) {
726            return Ok(Arc::clone(context));
727        }
728
729        // Create a DataView from the table (with visible rows if filtered)
730        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
731            // Create a filtered view
732            let view = DataView::new(Arc::new(self.table.clone()));
733            // Apply filtering based on visible rows
734            // Note: This is a simplified approach - in production we'd need proper filtering
735            view
736        } else {
737            DataView::new(Arc::new(self.table.clone()))
738        };
739
740        // Create the WindowContext with the full spec (including frame)
741        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
742
743        let context = Arc::new(context);
744        self.window_contexts.insert(key, Arc::clone(&context));
745        Ok(context)
746    }
747
748    /// Evaluate a window function
749    fn evaluate_window_function(
750        &mut self,
751        name: &str,
752        args: &[SqlExpression],
753        spec: &WindowSpec,
754        row_index: usize,
755    ) -> Result<DataValue> {
756        let name_upper = name.to_uppercase();
757
758        // First check if this is a syntactic sugar function in the registry
759        debug!("Looking for window function {} in registry", name_upper);
760        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
761            debug!("Found window function {} in registry", name_upper);
762
763            // Dereference to get the actual window function
764            let window_fn = window_fn_arc.as_ref();
765
766            // Validate arguments
767            window_fn.validate_args(args)?;
768
769            // Transform the window spec based on the function's requirements
770            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
771
772            // Get or create the window context with the transformed spec
773            let context = self.get_or_create_window_context(&transformed_spec)?;
774
775            // Create an expression evaluator adapter
776            struct EvaluatorAdapter<'a, 'b> {
777                evaluator: &'a mut ArithmeticEvaluator<'b>,
778                row_index: usize,
779            }
780
781            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
782                fn evaluate(
783                    &mut self,
784                    expr: &SqlExpression,
785                    _row_index: usize,
786                ) -> Result<DataValue> {
787                    self.evaluator.evaluate(expr, self.row_index)
788                }
789            }
790
791            let mut adapter = EvaluatorAdapter {
792                evaluator: self,
793                row_index,
794            };
795
796            // Call the window function's compute method
797            return window_fn.compute(&context, row_index, args, &mut adapter);
798        }
799
800        // Fall back to built-in window functions
801        let context = self.get_or_create_window_context(spec)?;
802
803        match name_upper.as_str() {
804            "LAG" => {
805                // LAG(column, offset, default)
806                if args.is_empty() {
807                    return Err(anyhow!("LAG requires at least 1 argument"));
808                }
809
810                // Get column name
811                let column = match &args[0] {
812                    SqlExpression::Column(col) => col.clone(),
813                    _ => return Err(anyhow!("LAG first argument must be a column")),
814                };
815
816                // Get offset (default 1)
817                let offset = if args.len() > 1 {
818                    match self.evaluate(&args[1], row_index)? {
819                        DataValue::Integer(i) => i as i32,
820                        _ => return Err(anyhow!("LAG offset must be an integer")),
821                    }
822                } else {
823                    1
824                };
825
826                // Get value at offset
827                Ok(context
828                    .get_offset_value(row_index, -offset, &column)
829                    .unwrap_or(DataValue::Null))
830            }
831            "LEAD" => {
832                // LEAD(column, offset, default)
833                if args.is_empty() {
834                    return Err(anyhow!("LEAD requires at least 1 argument"));
835                }
836
837                // Get column name
838                let column = match &args[0] {
839                    SqlExpression::Column(col) => col.clone(),
840                    _ => return Err(anyhow!("LEAD first argument must be a column")),
841                };
842
843                // Get offset (default 1)
844                let offset = if args.len() > 1 {
845                    match self.evaluate(&args[1], row_index)? {
846                        DataValue::Integer(i) => i as i32,
847                        _ => return Err(anyhow!("LEAD offset must be an integer")),
848                    }
849                } else {
850                    1
851                };
852
853                // Get value at offset
854                Ok(context
855                    .get_offset_value(row_index, offset, &column)
856                    .unwrap_or(DataValue::Null))
857            }
858            "ROW_NUMBER" => {
859                // ROW_NUMBER() - no arguments
860                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
861            }
862            "FIRST_VALUE" => {
863                // FIRST_VALUE(column) OVER (... ROWS ...)
864                if args.is_empty() {
865                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
866                }
867
868                let column = match &args[0] {
869                    SqlExpression::Column(col) => col.clone(),
870                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
871                };
872
873                // Use frame-aware version if frame is specified
874                if context.has_frame() {
875                    Ok(context
876                        .get_frame_first_value(row_index, &column)
877                        .unwrap_or(DataValue::Null))
878                } else {
879                    Ok(context
880                        .get_first_value(row_index, &column)
881                        .unwrap_or(DataValue::Null))
882                }
883            }
884            "LAST_VALUE" => {
885                // LAST_VALUE(column) OVER (... ROWS ...)
886                if args.is_empty() {
887                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
888                }
889
890                let column = match &args[0] {
891                    SqlExpression::Column(col) => col.clone(),
892                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
893                };
894
895                // Use frame-aware version if frame is specified
896                if context.has_frame() {
897                    Ok(context
898                        .get_frame_last_value(row_index, &column)
899                        .unwrap_or(DataValue::Null))
900                } else {
901                    Ok(context
902                        .get_last_value(row_index, &column)
903                        .unwrap_or(DataValue::Null))
904                }
905            }
906            "SUM" => {
907                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
908                if args.is_empty() {
909                    return Err(anyhow!("SUM requires 1 argument"));
910                }
911
912                let column = match &args[0] {
913                    SqlExpression::Column(col) => col.clone(),
914                    _ => return Err(anyhow!("SUM argument must be a column")),
915                };
916
917                // Use frame-aware sum if frame is specified, otherwise use partition sum
918                if context.has_frame() {
919                    Ok(context
920                        .get_frame_sum(row_index, &column)
921                        .unwrap_or(DataValue::Null))
922                } else {
923                    Ok(context
924                        .get_partition_sum(row_index, &column)
925                        .unwrap_or(DataValue::Null))
926                }
927            }
928            "AVG" => {
929                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
930                if args.is_empty() {
931                    return Err(anyhow!("AVG requires 1 argument"));
932                }
933
934                let column = match &args[0] {
935                    SqlExpression::Column(col) => col.clone(),
936                    _ => return Err(anyhow!("AVG argument must be a column")),
937                };
938
939                Ok(context
940                    .get_frame_avg(row_index, &column)
941                    .unwrap_or(DataValue::Null))
942            }
943            "STDDEV" | "STDEV" => {
944                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
945                if args.is_empty() {
946                    return Err(anyhow!("STDDEV requires 1 argument"));
947                }
948
949                let column = match &args[0] {
950                    SqlExpression::Column(col) => col.clone(),
951                    _ => return Err(anyhow!("STDDEV argument must be a column")),
952                };
953
954                Ok(context
955                    .get_frame_stddev(row_index, &column)
956                    .unwrap_or(DataValue::Null))
957            }
958            "VARIANCE" | "VAR" => {
959                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
960                if args.is_empty() {
961                    return Err(anyhow!("VARIANCE requires 1 argument"));
962                }
963
964                let column = match &args[0] {
965                    SqlExpression::Column(col) => col.clone(),
966                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
967                };
968
969                Ok(context
970                    .get_frame_variance(row_index, &column)
971                    .unwrap_or(DataValue::Null))
972            }
973            "MIN" => {
974                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
975                if args.is_empty() {
976                    return Err(anyhow!("MIN requires 1 argument"));
977                }
978
979                let column = match &args[0] {
980                    SqlExpression::Column(col) => col.clone(),
981                    _ => return Err(anyhow!("MIN argument must be a column")),
982                };
983
984                let frame_rows = context.get_frame_rows(row_index);
985                if frame_rows.is_empty() {
986                    return Ok(DataValue::Null);
987                }
988
989                let source_table = context.source();
990                let col_idx = source_table
991                    .get_column_index(&column)
992                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
993
994                let mut min_value: Option<DataValue> = None;
995                for &row_idx in &frame_rows {
996                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
997                        if !matches!(value, DataValue::Null) {
998                            match &min_value {
999                                None => min_value = Some(value.clone()),
1000                                Some(current_min) => {
1001                                    if value < current_min {
1002                                        min_value = Some(value.clone());
1003                                    }
1004                                }
1005                            }
1006                        }
1007                    }
1008                }
1009
1010                Ok(min_value.unwrap_or(DataValue::Null))
1011            }
1012            "MAX" => {
1013                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1014                if args.is_empty() {
1015                    return Err(anyhow!("MAX requires 1 argument"));
1016                }
1017
1018                let column = match &args[0] {
1019                    SqlExpression::Column(col) => col.clone(),
1020                    _ => return Err(anyhow!("MAX argument must be a column")),
1021                };
1022
1023                let frame_rows = context.get_frame_rows(row_index);
1024                if frame_rows.is_empty() {
1025                    return Ok(DataValue::Null);
1026                }
1027
1028                let source_table = context.source();
1029                let col_idx = source_table
1030                    .get_column_index(&column)
1031                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
1032
1033                let mut max_value: Option<DataValue> = None;
1034                for &row_idx in &frame_rows {
1035                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1036                        if !matches!(value, DataValue::Null) {
1037                            match &max_value {
1038                                None => max_value = Some(value.clone()),
1039                                Some(current_max) => {
1040                                    if value > current_max {
1041                                        max_value = Some(value.clone());
1042                                    }
1043                                }
1044                            }
1045                        }
1046                    }
1047                }
1048
1049                Ok(max_value.unwrap_or(DataValue::Null))
1050            }
1051            "COUNT" => {
1052                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1053                // Use frame-aware count if frame is specified, otherwise use partition count
1054
1055                if args.is_empty() {
1056                    // COUNT(*) OVER (...)
1057                    if context.has_frame() {
1058                        Ok(context
1059                            .get_frame_count(row_index, None)
1060                            .unwrap_or(DataValue::Null))
1061                    } else {
1062                        Ok(context
1063                            .get_partition_count(row_index, None)
1064                            .unwrap_or(DataValue::Null))
1065                    }
1066                } else {
1067                    // Check for COUNT(*)
1068                    let column = match &args[0] {
1069                        SqlExpression::Column(col) => {
1070                            if col == "*" {
1071                                // COUNT(*) - count all rows
1072                                if context.has_frame() {
1073                                    return Ok(context
1074                                        .get_frame_count(row_index, None)
1075                                        .unwrap_or(DataValue::Null));
1076                                } else {
1077                                    return Ok(context
1078                                        .get_partition_count(row_index, None)
1079                                        .unwrap_or(DataValue::Null));
1080                                }
1081                            }
1082                            col.clone()
1083                        }
1084                        SqlExpression::StringLiteral(s) if s == "*" => {
1085                            // COUNT(*) as StringLiteral
1086                            if context.has_frame() {
1087                                return Ok(context
1088                                    .get_frame_count(row_index, None)
1089                                    .unwrap_or(DataValue::Null));
1090                            } else {
1091                                return Ok(context
1092                                    .get_partition_count(row_index, None)
1093                                    .unwrap_or(DataValue::Null));
1094                            }
1095                        }
1096                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1097                    };
1098
1099                    // COUNT(column) - count non-null values
1100                    if context.has_frame() {
1101                        Ok(context
1102                            .get_frame_count(row_index, Some(&column))
1103                            .unwrap_or(DataValue::Null))
1104                    } else {
1105                        Ok(context
1106                            .get_partition_count(row_index, Some(&column))
1107                            .unwrap_or(DataValue::Null))
1108                    }
1109                }
1110            }
1111            _ => Err(anyhow!("Unknown window function: {}", name)),
1112        }
1113    }
1114
1115    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1116    fn evaluate_method_call(
1117        &mut self,
1118        object: &str,
1119        method: &str,
1120        args: &[SqlExpression],
1121        row_index: usize,
1122    ) -> Result<DataValue> {
1123        // Get column value
1124        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1125            let suggestion = self.find_similar_column(object);
1126            match suggestion {
1127                Some(similar) => {
1128                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1129                }
1130                None => anyhow!("Column '{}' not found", object),
1131            }
1132        })?;
1133
1134        let cell_value = self.table.get_value(row_index, col_index).cloned();
1135
1136        self.evaluate_method_on_value(
1137            &cell_value.unwrap_or(DataValue::Null),
1138            method,
1139            args,
1140            row_index,
1141        )
1142    }
1143
1144    /// Evaluate a method on a value
1145    fn evaluate_method_on_value(
1146        &mut self,
1147        value: &DataValue,
1148        method: &str,
1149        args: &[SqlExpression],
1150        row_index: usize,
1151    ) -> Result<DataValue> {
1152        // First, try to proxy the method through the function registry
1153        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1154
1155        // Map method names to function names (case-insensitive matching)
1156        let function_name = match method.to_lowercase().as_str() {
1157            "trim" => "TRIM",
1158            "trimstart" | "trimbegin" => "TRIMSTART",
1159            "trimend" => "TRIMEND",
1160            "length" | "len" => "LENGTH",
1161            "contains" => "CONTAINS",
1162            "startswith" => "STARTSWITH",
1163            "endswith" => "ENDSWITH",
1164            "indexof" => "INDEXOF",
1165            _ => method, // Try the method name as-is
1166        };
1167
1168        // Check if we have this function in the registry
1169        if self.function_registry.get(function_name).is_some() {
1170            debug!(
1171                "Proxying method '{}' through function registry as '{}'",
1172                method, function_name
1173            );
1174
1175            // Prepare arguments: receiver is the first argument, followed by method args
1176            let mut func_args = vec![value.clone()];
1177
1178            // Evaluate method arguments and add them
1179            for arg in args {
1180                func_args.push(self.evaluate(arg, row_index)?);
1181            }
1182
1183            // Get the function and call it
1184            let func = self.function_registry.get(function_name).unwrap();
1185            return func.evaluate(&func_args);
1186        }
1187
1188        // If not in registry, the method is not supported
1189        // All methods should be registered in the function registry
1190        Err(anyhow!(
1191            "Method '{}' not found. It should be registered in the function registry.",
1192            method
1193        ))
1194    }
1195
1196    /// Evaluate a CASE expression
1197    fn evaluate_case_expression(
1198        &mut self,
1199        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1200        else_branch: &Option<Box<SqlExpression>>,
1201        row_index: usize,
1202    ) -> Result<DataValue> {
1203        debug!(
1204            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1205            row_index
1206        );
1207
1208        // Evaluate each WHEN condition in order
1209        for branch in when_branches {
1210            // Evaluate the condition as a boolean
1211            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1212
1213            if condition_result {
1214                debug!("CASE: WHEN condition matched, evaluating result expression");
1215                return self.evaluate(&branch.result, row_index);
1216            }
1217        }
1218
1219        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1220        if let Some(else_expr) = else_branch {
1221            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1222            self.evaluate(else_expr, row_index)
1223        } else {
1224            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1225            Ok(DataValue::Null)
1226        }
1227    }
1228
1229    /// Evaluate a simple CASE expression
1230    fn evaluate_simple_case_expression(
1231        &mut self,
1232        expr: &Box<SqlExpression>,
1233        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1234        else_branch: &Option<Box<SqlExpression>>,
1235        row_index: usize,
1236    ) -> Result<DataValue> {
1237        debug!(
1238            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1239            row_index
1240        );
1241
1242        // Evaluate the main expression once
1243        let case_value = self.evaluate(expr, row_index)?;
1244        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1245
1246        // Compare against each WHEN value in order
1247        for branch in when_branches {
1248            // Evaluate the WHEN value
1249            let when_value = self.evaluate(&branch.value, row_index)?;
1250
1251            // Check for equality
1252            if self.values_equal(&case_value, &when_value)? {
1253                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1254                return self.evaluate(&branch.result, row_index);
1255            }
1256        }
1257
1258        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1259        if let Some(else_expr) = else_branch {
1260            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1261            self.evaluate(else_expr, row_index)
1262        } else {
1263            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1264            Ok(DataValue::Null)
1265        }
1266    }
1267
1268    /// Check if two DataValues are equal
1269    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1270        match (left, right) {
1271            (DataValue::Null, DataValue::Null) => Ok(true),
1272            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1273            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1274            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1275            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1276            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1277            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1278            // Type coercion for numeric comparisons
1279            (DataValue::Integer(a), DataValue::Float(b)) => {
1280                Ok((*a as f64 - b).abs() < f64::EPSILON)
1281            }
1282            (DataValue::Float(a), DataValue::Integer(b)) => {
1283                Ok((a - *b as f64).abs() < f64::EPSILON)
1284            }
1285            _ => Ok(false),
1286        }
1287    }
1288
1289    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1290    fn evaluate_condition_as_bool(
1291        &mut self,
1292        expr: &SqlExpression,
1293        row_index: usize,
1294    ) -> Result<bool> {
1295        let value = self.evaluate(expr, row_index)?;
1296
1297        match value {
1298            DataValue::Boolean(b) => Ok(b),
1299            DataValue::Integer(i) => Ok(i != 0),
1300            DataValue::Float(f) => Ok(f != 0.0),
1301            DataValue::Null => Ok(false),
1302            DataValue::String(s) => Ok(!s.is_empty()),
1303            DataValue::InternedString(s) => Ok(!s.is_empty()),
1304            _ => Ok(true), // Other types are considered truthy
1305        }
1306    }
1307}
1308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table: a = Integer(10), b = Float(2.5), c = Integer(4).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for column_name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(column_name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    /// A column reference evaluates to the cell value of that column.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column("a".to_string());
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    /// Number literals become `Integer` or `Float` depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal, expected) in cases {
            let expr = SqlExpression::NumberLiteral(literal.to_string());
            let result = evaluator.evaluate(&expr, 0).unwrap();
            assert_eq!(result, expected);
        }
    }

    /// Addition keeps integers integral and widens mixed operands to float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let five = DataValue::Integer(5);

        // Integer + Integer
        let int_sum = evaluator.add_values(&five, &DataValue::Integer(3)).unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        // Integer + Float
        let mixed_sum = evaluator.add_values(&five, &DataValue::Float(2.5)).unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Integer * Float yields a float product.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Exact integer division stays integral; inexact division becomes a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let ten = DataValue::Integer(10);

        // Exact division
        let exact = evaluator
            .divide_values(&ten, &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        // Non-exact division
        let inexact = evaluator
            .divide_values(&ten, &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by zero is reported as an error, not a panic or infinity.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// A binary expression over two columns is evaluated against the row.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let lhs = Box::new(SqlExpression::Column("a".to_string()));
        let rhs = Box::new(SqlExpression::Column("b".to_string()));
        let product_expr = SqlExpression::BinaryOp {
            left: lhs,
            op: "*".to_string(),
            right: rhs,
        };

        assert_eq!(
            evaluator.evaluate(&product_expr, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}