sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
18/// This is different from `RecursiveWhereEvaluator` which returns boolean
19pub struct ArithmeticEvaluator<'a> {
20    table: &'a DataTable,
21    date_notation: String,
22    function_registry: Arc<FunctionRegistry>,
23    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
24    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
25    window_function_registry: Arc<WindowFunctionRegistry>,
26    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
27    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
28    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
29}
30
31impl<'a> ArithmeticEvaluator<'a> {
32    #[must_use]
33    pub fn new(table: &'a DataTable) -> Self {
34        Self {
35            table,
36            date_notation: get_date_notation(),
37            function_registry: Arc::new(FunctionRegistry::new()),
38            aggregate_registry: Arc::new(AggregateRegistry::new()),
39            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41            visible_rows: None,
42            window_contexts: HashMap::new(),
43            table_aliases: HashMap::new(),
44        }
45    }
46
47    #[must_use]
48    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49        Self {
50            table,
51            date_notation,
52            function_registry: Arc::new(FunctionRegistry::new()),
53            aggregate_registry: Arc::new(AggregateRegistry::new()),
54            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56            visible_rows: None,
57            window_contexts: HashMap::new(),
58            table_aliases: HashMap::new(),
59        }
60    }
61
62    /// Set visible rows for aggregate functions (for filtered views)
63    #[must_use]
64    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65        self.visible_rows = Some(rows);
66        self
67    }
68
69    /// Set table aliases for qualified column resolution
70    #[must_use]
71    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72        self.table_aliases = aliases;
73        self
74    }
75
76    #[must_use]
77    pub fn with_date_notation_and_registry(
78        table: &'a DataTable,
79        date_notation: String,
80        function_registry: Arc<FunctionRegistry>,
81    ) -> Self {
82        Self {
83            table,
84            date_notation,
85            function_registry,
86            aggregate_registry: Arc::new(AggregateRegistry::new()),
87            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89            visible_rows: None,
90            window_contexts: HashMap::new(),
91            table_aliases: HashMap::new(),
92        }
93    }
94
95    /// Find a column name similar to the given name using edit distance
96    fn find_similar_column(&self, name: &str) -> Option<String> {
97        let columns = self.table.column_names();
98        let mut best_match: Option<(String, usize)> = None;
99
100        for col in columns {
101            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102            // Only suggest if distance is small (likely a typo)
103            // Allow up to 3 edits for longer names
104            let max_distance = if name.len() > 10 { 3 } else { 2 };
105            if distance <= max_distance {
106                match &best_match {
107                    None => best_match = Some((col, distance)),
108                    Some((_, best_dist)) if distance < *best_dist => {
109                        best_match = Some((col, distance));
110                    }
111                    _ => {}
112                }
113            }
114        }
115
116        best_match.map(|(name, _)| name)
117    }
118
119    /// Calculate Levenshtein edit distance between two strings
120    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121        // Use the shared implementation from string_methods
122        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123    }
124
125    /// Evaluate an SQL expression to produce a `DataValue`
126    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127        debug!(
128            "ArithmeticEvaluator: evaluating {:?} for row {}",
129            expr, row_index
130        );
131
132        match expr {
133            SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137            SqlExpression::Null => Ok(DataValue::Null),
138            SqlExpression::BinaryOp { left, op, right } => {
139                self.evaluate_binary_op(left, op, right, row_index)
140            }
141            SqlExpression::FunctionCall {
142                name,
143                args,
144                distinct,
145            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146            SqlExpression::WindowFunction {
147                name,
148                args,
149                window_spec,
150            } => self.evaluate_window_function(name, args, window_spec, row_index),
151            SqlExpression::MethodCall {
152                object,
153                method,
154                args,
155            } => self.evaluate_method_call(object, method, args, row_index),
156            SqlExpression::ChainedMethodCall { base, method, args } => {
157                // Evaluate the base expression first, then apply the method
158                let base_value = self.evaluate(base, row_index)?;
159                self.evaluate_method_on_value(&base_value, method, args, row_index)
160            }
161            SqlExpression::CaseExpression {
162                when_branches,
163                else_branch,
164            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165            SqlExpression::SimpleCaseExpression {
166                expr,
167                when_branches,
168                else_branch,
169            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170            _ => Err(anyhow!(
171                "Unsupported expression type for arithmetic evaluation: {:?}",
172                expr
173            )),
174        }
175    }
176
177    /// Evaluate a column reference with proper table scoping
178    fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
179        if let Some(table_prefix) = &column_ref.table_prefix {
180            // For qualified references, ONLY try qualified lookup - no fallback
181            let qualified_name = format!("{}.{}", table_prefix, column_ref.name);
182
183            if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
184                return self
185                    .table
186                    .get_value(row_index, col_idx)
187                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
188                    .map(|v| v.clone());
189            }
190
191            // If not found, return error - don't fall back
192            Err(anyhow!("Column '{}' not found", qualified_name))
193        } else {
194            // Simple column name lookup
195            self.evaluate_column(&column_ref.name, row_index)
196        }
197    }
198
199    /// Evaluate a column reference
200    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
201        // First try to resolve qualified column names (table.column or alias.column)
202        let resolved_column = if column_name.contains('.') {
203            // Split on last dot to handle cases like "schema.table.column"
204            if let Some(dot_pos) = column_name.rfind('.') {
205                let _table_or_alias = &column_name[..dot_pos];
206                let col_name = &column_name[dot_pos + 1..];
207
208                // For now, just use the column name part
209                // In the future, we could validate the table/alias part
210                debug!(
211                    "Resolving qualified column: {} -> {}",
212                    column_name, col_name
213                );
214                col_name.to_string()
215            } else {
216                column_name.to_string()
217            }
218        } else {
219            column_name.to_string()
220        };
221
222        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
223            idx
224        } else if resolved_column != column_name {
225            // If not found, try the original name
226            if let Some(idx) = self.table.get_column_index(column_name) {
227                idx
228            } else {
229                let suggestion = self.find_similar_column(&resolved_column);
230                return Err(match suggestion {
231                    Some(similar) => anyhow!(
232                        "Column '{}' not found. Did you mean '{}'?",
233                        column_name,
234                        similar
235                    ),
236                    None => anyhow!("Column '{}' not found", column_name),
237                });
238            }
239        } else {
240            let suggestion = self.find_similar_column(&resolved_column);
241            return Err(match suggestion {
242                Some(similar) => anyhow!(
243                    "Column '{}' not found. Did you mean '{}'?",
244                    column_name,
245                    similar
246                ),
247                None => anyhow!("Column '{}' not found", column_name),
248            });
249        };
250
251        if row_index >= self.table.row_count() {
252            return Err(anyhow!("Row index {} out of bounds", row_index));
253        }
254
255        let row = self
256            .table
257            .get_row(row_index)
258            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
259
260        let value = row
261            .get(col_index)
262            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
263
264        Ok(value.clone())
265    }
266
267    /// Evaluate a number literal (handles both integers and floats)
268    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
269        // Try to parse as integer first
270        if let Ok(int_val) = number_str.parse::<i64>() {
271            return Ok(DataValue::Integer(int_val));
272        }
273
274        // If that fails, try as float
275        if let Ok(float_val) = number_str.parse::<f64>() {
276            return Ok(DataValue::Float(float_val));
277        }
278
279        Err(anyhow!("Invalid number literal: {}", number_str))
280    }
281
282    /// Evaluate a binary operation (arithmetic)
283    fn evaluate_binary_op(
284        &mut self,
285        left: &SqlExpression,
286        op: &str,
287        right: &SqlExpression,
288        row_index: usize,
289    ) -> Result<DataValue> {
290        let left_val = self.evaluate(left, row_index)?;
291        let right_val = self.evaluate(right, row_index)?;
292
293        debug!(
294            "ArithmeticEvaluator: {} {} {}",
295            self.format_value(&left_val),
296            op,
297            self.format_value(&right_val)
298        );
299
300        match op {
301            "+" => self.add_values(&left_val, &right_val),
302            "-" => self.subtract_values(&left_val, &right_val),
303            "*" => self.multiply_values(&left_val, &right_val),
304            "/" => self.divide_values(&left_val, &right_val),
305            "%" => {
306                // Modulo operator - call MOD function
307                let args = vec![left.clone(), right.clone()];
308                self.evaluate_function("MOD", &args, row_index)
309            }
310            // Comparison operators (return boolean results)
311            // Use centralized comparison logic for consistency
312            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
313                let result = compare_with_op(&left_val, &right_val, op, false);
314                Ok(DataValue::Boolean(result))
315            }
316            // IS NULL / IS NOT NULL operators
317            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
318            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
319            // Logical operators
320            "AND" => {
321                let left_bool = self.to_bool(&left_val)?;
322                let right_bool = self.to_bool(&right_val)?;
323                Ok(DataValue::Boolean(left_bool && right_bool))
324            }
325            "OR" => {
326                let left_bool = self.to_bool(&left_val)?;
327                let right_bool = self.to_bool(&right_val)?;
328                Ok(DataValue::Boolean(left_bool || right_bool))
329            }
330            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
331        }
332    }
333
334    /// Add two `DataValues` with type coercion
335    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
336        // NULL handling - any operation with NULL returns NULL
337        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
338            return Ok(DataValue::Null);
339        }
340
341        match (left, right) {
342            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
343            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
344            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
345            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
346            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
347        }
348    }
349
350    /// Subtract two `DataValues` with type coercion
351    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
352        // NULL handling - any operation with NULL returns NULL
353        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
354            return Ok(DataValue::Null);
355        }
356
357        match (left, right) {
358            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
359            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
360            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
361            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
362            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
363        }
364    }
365
366    /// Multiply two `DataValues` with type coercion
367    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
368        // NULL handling - any operation with NULL returns NULL
369        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
370            return Ok(DataValue::Null);
371        }
372
373        match (left, right) {
374            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
375            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
376            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
377            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
378            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
379        }
380    }
381
382    /// Divide two `DataValues` with type coercion
383    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
384        // NULL handling - any operation with NULL returns NULL
385        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
386            return Ok(DataValue::Null);
387        }
388
389        // Check for division by zero first
390        let is_zero = match right {
391            DataValue::Integer(0) => true,
392            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
393            _ => false,
394        };
395
396        if is_zero {
397            return Err(anyhow!("Division by zero"));
398        }
399
400        match (left, right) {
401            (DataValue::Integer(a), DataValue::Integer(b)) => {
402                // Integer division - if result is exact, keep as int, otherwise promote to float
403                if a % b == 0 {
404                    Ok(DataValue::Integer(a / b))
405                } else {
406                    Ok(DataValue::Float(*a as f64 / *b as f64))
407                }
408            }
409            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
410            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
411            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
412            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
413        }
414    }
415
416    /// Format a `DataValue` for debug output
417    fn format_value(&self, value: &DataValue) -> String {
418        match value {
419            DataValue::Integer(i) => i.to_string(),
420            DataValue::Float(f) => f.to_string(),
421            DataValue::String(s) => format!("'{s}'"),
422            _ => format!("{value:?}"),
423        }
424    }
425
426    /// Convert a DataValue to boolean for logical operations
427    fn to_bool(&self, value: &DataValue) -> Result<bool> {
428        match value {
429            DataValue::Boolean(b) => Ok(*b),
430            DataValue::Integer(i) => Ok(*i != 0),
431            DataValue::Float(f) => Ok(*f != 0.0),
432            DataValue::Null => Ok(false),
433            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
434        }
435    }
436
437    /// Evaluate a function call
438    fn evaluate_function_with_distinct(
439        &mut self,
440        name: &str,
441        args: &[SqlExpression],
442        distinct: bool,
443        row_index: usize,
444    ) -> Result<DataValue> {
445        // If DISTINCT is specified, handle it specially for aggregate functions
446        if distinct {
447            let name_upper = name.to_uppercase();
448
449            // Check if it's an aggregate function in either registry
450            if self.aggregate_registry.is_aggregate(&name_upper)
451                || self.new_aggregate_registry.contains(&name_upper)
452            {
453                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
454            } else {
455                return Err(anyhow!(
456                    "DISTINCT can only be used with aggregate functions"
457                ));
458            }
459        }
460
461        // Otherwise, use the regular evaluation
462        self.evaluate_function(name, args, row_index)
463    }
464
465    fn evaluate_aggregate_with_distinct(
466        &mut self,
467        name: &str,
468        args: &[SqlExpression],
469        _row_index: usize,
470    ) -> Result<DataValue> {
471        let name_upper = name.to_uppercase();
472
473        // Check new aggregate registry first for migrated functions
474        if self.new_aggregate_registry.get(&name_upper).is_some() {
475            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
476                visible.clone()
477            } else {
478                (0..self.table.rows.len()).collect()
479            };
480
481            // Collect and deduplicate values for DISTINCT
482            let mut vals = Vec::new();
483            for &row_idx in &rows_to_process {
484                if !args.is_empty() {
485                    let value = self.evaluate(&args[0], row_idx)?;
486                    vals.push(value);
487                }
488            }
489
490            // Deduplicate values
491            let mut seen = HashSet::new();
492            let unique_values: Vec<_> = vals
493                .into_iter()
494                .filter(|v| {
495                    let key = format!("{:?}", v);
496                    seen.insert(key)
497                })
498                .collect();
499
500            // Get the aggregate function from the new registry
501            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
502            let mut state = agg_func.create_state();
503
504            // Use unique values
505            for value in &unique_values {
506                state.accumulate(value)?;
507            }
508
509            return Ok(state.finalize());
510        }
511
512        // Check old aggregate registry (DISTINCT handling)
513        if self.aggregate_registry.get(&name_upper).is_some() {
514            // Determine which rows to process first
515            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
516                visible.clone()
517            } else {
518                (0..self.table.rows.len()).collect()
519            };
520
521            // Special handling for STRING_AGG with separator parameter
522            if name_upper == "STRING_AGG" && args.len() >= 2 {
523                // STRING_AGG(DISTINCT column, separator)
524                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
525                    // Evaluate the separator (second argument) once
526                    if args.len() >= 2 {
527                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
528                        match separator {
529                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
530                            DataValue::InternedString(s) => {
531                                crate::sql::aggregates::StringAggState::new(&s)
532                            }
533                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
534                        }
535                    } else {
536                        crate::sql::aggregates::StringAggState::new(",")
537                    },
538                );
539
540                // Evaluate the first argument (column) for each row and accumulate
541                // Handle DISTINCT - use a HashSet to track seen values
542                let mut seen_values = HashSet::new();
543
544                for &row_idx in &rows_to_process {
545                    let value = self.evaluate(&args[0], row_idx)?;
546
547                    // Skip if we've seen this value
548                    if !seen_values.insert(value.clone()) {
549                        continue; // Skip duplicate values
550                    }
551
552                    // Now get the aggregate function and accumulate
553                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
554                    agg_func.accumulate(&mut state, &value)?;
555                }
556
557                // Finalize the aggregate
558                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
559                return Ok(agg_func.finalize(state));
560            }
561
562            // For other aggregates with DISTINCT
563            // Evaluate the argument expression for each row
564            let mut vals = Vec::new();
565            for &row_idx in &rows_to_process {
566                if !args.is_empty() {
567                    let value = self.evaluate(&args[0], row_idx)?;
568                    vals.push(value);
569                }
570            }
571
572            // Deduplicate values for DISTINCT
573            let mut seen = HashSet::new();
574            let mut unique_values = Vec::new();
575            for value in vals {
576                if seen.insert(value.clone()) {
577                    unique_values.push(value);
578                }
579            }
580
581            // Now get the aggregate function and process
582            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
583            let mut state = agg_func.init();
584
585            // Use unique values
586            for value in &unique_values {
587                agg_func.accumulate(&mut state, value)?;
588            }
589
590            return Ok(agg_func.finalize(state));
591        }
592
593        Err(anyhow!("Unknown aggregate function: {}", name))
594    }
595
596    fn evaluate_function(
597        &mut self,
598        name: &str,
599        args: &[SqlExpression],
600        row_index: usize,
601    ) -> Result<DataValue> {
602        // Check if this is an aggregate function
603        let name_upper = name.to_uppercase();
604
605        // Check new aggregate registry first (for migrated functions)
606        if self.new_aggregate_registry.get(&name_upper).is_some() {
607            // Use new registry for SUM
608            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
609                visible.clone()
610            } else {
611                (0..self.table.rows.len()).collect()
612            };
613
614            // Get the aggregate function from the new registry
615            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
616            let mut state = agg_func.create_state();
617
618            // Special handling for COUNT(*)
619            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
620                if args.is_empty()
621                    || (args.len() == 1
622                        && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
623                    || (args.len() == 1
624                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
625                {
626                    // COUNT(*) or COUNT_STAR - count all rows
627                    for _ in &rows_to_process {
628                        state.accumulate(&DataValue::Integer(1))?;
629                    }
630                } else {
631                    // COUNT(column) - count non-null values
632                    for &row_idx in &rows_to_process {
633                        let value = self.evaluate(&args[0], row_idx)?;
634                        state.accumulate(&value)?;
635                    }
636                }
637            } else {
638                // Other aggregates - evaluate arguments and accumulate
639                if !args.is_empty() {
640                    for &row_idx in &rows_to_process {
641                        let value = self.evaluate(&args[0], row_idx)?;
642                        state.accumulate(&value)?;
643                    }
644                }
645            }
646
647            return Ok(state.finalize());
648        }
649
650        // Check old aggregate registry (for non-migrated functions)
651        if self.aggregate_registry.get(&name_upper).is_some() {
652            // Determine which rows to process first
653            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
654                visible.clone()
655            } else {
656                (0..self.table.rows.len()).collect()
657            };
658
659            // Special handling for STRING_AGG with separator parameter
660            if name_upper == "STRING_AGG" && args.len() >= 2 {
661                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
662                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
663                    // Evaluate the separator (second argument) once
664                    if args.len() >= 2 {
665                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
666                        match separator {
667                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
668                            DataValue::InternedString(s) => {
669                                crate::sql::aggregates::StringAggState::new(&s)
670                            }
671                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
672                        }
673                    } else {
674                        crate::sql::aggregates::StringAggState::new(",")
675                    },
676                );
677
678                // Evaluate the first argument (column) for each row and accumulate
679                for &row_idx in &rows_to_process {
680                    let value = self.evaluate(&args[0], row_idx)?;
681                    // Now get the aggregate function and accumulate
682                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
683                    agg_func.accumulate(&mut state, &value)?;
684                }
685
686                // Finalize the aggregate
687                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
688                return Ok(agg_func.finalize(state));
689            }
690
691            // Evaluate arguments first if needed (to avoid borrow issues)
692            let values = if !args.is_empty()
693                && !(args.len() == 1
694                    && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
695            {
696                // Evaluate the argument expression for each row
697                let mut vals = Vec::new();
698                for &row_idx in &rows_to_process {
699                    let value = self.evaluate(&args[0], row_idx)?;
700                    vals.push(value);
701                }
702                Some(vals)
703            } else {
704                None
705            };
706
707            // Now get the aggregate function and process
708            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
709            let mut state = agg_func.init();
710
711            if let Some(values) = values {
712                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
713                for value in &values {
714                    agg_func.accumulate(&mut state, value)?;
715                }
716            } else {
717                // COUNT(*) case
718                for _ in &rows_to_process {
719                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
720                }
721            }
722
723            return Ok(agg_func.finalize(state));
724        }
725
726        // First check if this function exists in the registry
727        if self.function_registry.get(name).is_some() {
728            // Evaluate all arguments first to avoid borrow issues
729            let mut evaluated_args = Vec::new();
730            for arg in args {
731                evaluated_args.push(self.evaluate(arg, row_index)?);
732            }
733
734            // Get the function and call it
735            let func = self.function_registry.get(name).unwrap();
736            return func.evaluate(&evaluated_args);
737        }
738
739        // If not in registry, return error for unknown function
740        Err(anyhow!("Unknown function: {}", name))
741    }
742
743    /// Get or create a WindowContext for the given specification
744    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
745        // Create a key for caching based on the spec
746        let key = format!("{:?}", spec);
747
748        if let Some(context) = self.window_contexts.get(&key) {
749            return Ok(Arc::clone(context));
750        }
751
752        // Create a DataView from the table (with visible rows if filtered)
753        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
754            // Create a filtered view
755            let view = DataView::new(Arc::new(self.table.clone()));
756            // Apply filtering based on visible rows
757            // Note: This is a simplified approach - in production we'd need proper filtering
758            view
759        } else {
760            DataView::new(Arc::new(self.table.clone()))
761        };
762
763        // Create the WindowContext with the full spec (including frame)
764        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
765
766        let context = Arc::new(context);
767        self.window_contexts.insert(key, Arc::clone(&context));
768        Ok(context)
769    }
770
771    /// Evaluate a window function
772    fn evaluate_window_function(
773        &mut self,
774        name: &str,
775        args: &[SqlExpression],
776        spec: &WindowSpec,
777        row_index: usize,
778    ) -> Result<DataValue> {
779        let name_upper = name.to_uppercase();
780
781        // First check if this is a syntactic sugar function in the registry
782        debug!("Looking for window function {} in registry", name_upper);
783        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
784            debug!("Found window function {} in registry", name_upper);
785
786            // Dereference to get the actual window function
787            let window_fn = window_fn_arc.as_ref();
788
789            // Validate arguments
790            window_fn.validate_args(args)?;
791
792            // Transform the window spec based on the function's requirements
793            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
794
795            // Get or create the window context with the transformed spec
796            let context = self.get_or_create_window_context(&transformed_spec)?;
797
798            // Create an expression evaluator adapter
799            struct EvaluatorAdapter<'a, 'b> {
800                evaluator: &'a mut ArithmeticEvaluator<'b>,
801                row_index: usize,
802            }
803
804            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
805                fn evaluate(
806                    &mut self,
807                    expr: &SqlExpression,
808                    _row_index: usize,
809                ) -> Result<DataValue> {
810                    self.evaluator.evaluate(expr, self.row_index)
811                }
812            }
813
814            let mut adapter = EvaluatorAdapter {
815                evaluator: self,
816                row_index,
817            };
818
819            // Call the window function's compute method
820            return window_fn.compute(&context, row_index, args, &mut adapter);
821        }
822
823        // Fall back to built-in window functions
824        let context = self.get_or_create_window_context(spec)?;
825
826        match name_upper.as_str() {
827            "LAG" => {
828                // LAG(column, offset, default)
829                if args.is_empty() {
830                    return Err(anyhow!("LAG requires at least 1 argument"));
831                }
832
833                // Get column name
834                let column = match &args[0] {
835                    SqlExpression::Column(col) => col.clone(),
836                    _ => return Err(anyhow!("LAG first argument must be a column")),
837                };
838
839                // Get offset (default 1)
840                let offset = if args.len() > 1 {
841                    match self.evaluate(&args[1], row_index)? {
842                        DataValue::Integer(i) => i as i32,
843                        _ => return Err(anyhow!("LAG offset must be an integer")),
844                    }
845                } else {
846                    1
847                };
848
849                // Get value at offset
850                Ok(context
851                    .get_offset_value(row_index, -offset, &column.name)
852                    .unwrap_or(DataValue::Null))
853            }
854            "LEAD" => {
855                // LEAD(column, offset, default)
856                if args.is_empty() {
857                    return Err(anyhow!("LEAD requires at least 1 argument"));
858                }
859
860                // Get column name
861                let column = match &args[0] {
862                    SqlExpression::Column(col) => col.clone(),
863                    _ => return Err(anyhow!("LEAD first argument must be a column")),
864                };
865
866                // Get offset (default 1)
867                let offset = if args.len() > 1 {
868                    match self.evaluate(&args[1], row_index)? {
869                        DataValue::Integer(i) => i as i32,
870                        _ => return Err(anyhow!("LEAD offset must be an integer")),
871                    }
872                } else {
873                    1
874                };
875
876                // Get value at offset
877                Ok(context
878                    .get_offset_value(row_index, offset, &column.name)
879                    .unwrap_or(DataValue::Null))
880            }
881            "ROW_NUMBER" => {
882                // ROW_NUMBER() - no arguments
883                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
884            }
885            "FIRST_VALUE" => {
886                // FIRST_VALUE(column) OVER (... ROWS ...)
887                if args.is_empty() {
888                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
889                }
890
891                let column = match &args[0] {
892                    SqlExpression::Column(col) => col.clone(),
893                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
894                };
895
896                // Use frame-aware version if frame is specified
897                if context.has_frame() {
898                    Ok(context
899                        .get_frame_first_value(row_index, &column.name)
900                        .unwrap_or(DataValue::Null))
901                } else {
902                    Ok(context
903                        .get_first_value(row_index, &column.name)
904                        .unwrap_or(DataValue::Null))
905                }
906            }
907            "LAST_VALUE" => {
908                // LAST_VALUE(column) OVER (... ROWS ...)
909                if args.is_empty() {
910                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
911                }
912
913                let column = match &args[0] {
914                    SqlExpression::Column(col) => col.clone(),
915                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
916                };
917
918                // Use frame-aware version if frame is specified
919                if context.has_frame() {
920                    Ok(context
921                        .get_frame_last_value(row_index, &column.name)
922                        .unwrap_or(DataValue::Null))
923                } else {
924                    Ok(context
925                        .get_last_value(row_index, &column.name)
926                        .unwrap_or(DataValue::Null))
927                }
928            }
929            "SUM" => {
930                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
931                if args.is_empty() {
932                    return Err(anyhow!("SUM requires 1 argument"));
933                }
934
935                let column = match &args[0] {
936                    SqlExpression::Column(col) => col.clone(),
937                    _ => return Err(anyhow!("SUM argument must be a column")),
938                };
939
940                // Use frame-aware sum if frame is specified, otherwise use partition sum
941                if context.has_frame() {
942                    Ok(context
943                        .get_frame_sum(row_index, &column.name)
944                        .unwrap_or(DataValue::Null))
945                } else {
946                    Ok(context
947                        .get_partition_sum(row_index, &column.name)
948                        .unwrap_or(DataValue::Null))
949                }
950            }
951            "AVG" => {
952                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
953                if args.is_empty() {
954                    return Err(anyhow!("AVG requires 1 argument"));
955                }
956
957                let column = match &args[0] {
958                    SqlExpression::Column(col) => col.clone(),
959                    _ => return Err(anyhow!("AVG argument must be a column")),
960                };
961
962                Ok(context
963                    .get_frame_avg(row_index, &column.name)
964                    .unwrap_or(DataValue::Null))
965            }
966            "STDDEV" | "STDEV" => {
967                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
968                if args.is_empty() {
969                    return Err(anyhow!("STDDEV requires 1 argument"));
970                }
971
972                let column = match &args[0] {
973                    SqlExpression::Column(col) => col.clone(),
974                    _ => return Err(anyhow!("STDDEV argument must be a column")),
975                };
976
977                Ok(context
978                    .get_frame_stddev(row_index, &column.name)
979                    .unwrap_or(DataValue::Null))
980            }
981            "VARIANCE" | "VAR" => {
982                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
983                if args.is_empty() {
984                    return Err(anyhow!("VARIANCE requires 1 argument"));
985                }
986
987                let column = match &args[0] {
988                    SqlExpression::Column(col) => col.clone(),
989                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
990                };
991
992                Ok(context
993                    .get_frame_variance(row_index, &column.name)
994                    .unwrap_or(DataValue::Null))
995            }
996            "MIN" => {
997                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
998                if args.is_empty() {
999                    return Err(anyhow!("MIN requires 1 argument"));
1000                }
1001
1002                let column = match &args[0] {
1003                    SqlExpression::Column(col) => col.clone(),
1004                    _ => return Err(anyhow!("MIN argument must be a column")),
1005                };
1006
1007                let frame_rows = context.get_frame_rows(row_index);
1008                if frame_rows.is_empty() {
1009                    return Ok(DataValue::Null);
1010                }
1011
1012                let source_table = context.source();
1013                let col_idx = source_table
1014                    .get_column_index(&column.name)
1015                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1016
1017                let mut min_value: Option<DataValue> = None;
1018                for &row_idx in &frame_rows {
1019                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1020                        if !matches!(value, DataValue::Null) {
1021                            match &min_value {
1022                                None => min_value = Some(value.clone()),
1023                                Some(current_min) => {
1024                                    if value < current_min {
1025                                        min_value = Some(value.clone());
1026                                    }
1027                                }
1028                            }
1029                        }
1030                    }
1031                }
1032
1033                Ok(min_value.unwrap_or(DataValue::Null))
1034            }
1035            "MAX" => {
1036                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1037                if args.is_empty() {
1038                    return Err(anyhow!("MAX requires 1 argument"));
1039                }
1040
1041                let column = match &args[0] {
1042                    SqlExpression::Column(col) => col.clone(),
1043                    _ => return Err(anyhow!("MAX argument must be a column")),
1044                };
1045
1046                let frame_rows = context.get_frame_rows(row_index);
1047                if frame_rows.is_empty() {
1048                    return Ok(DataValue::Null);
1049                }
1050
1051                let source_table = context.source();
1052                let col_idx = source_table
1053                    .get_column_index(&column.name)
1054                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1055
1056                let mut max_value: Option<DataValue> = None;
1057                for &row_idx in &frame_rows {
1058                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1059                        if !matches!(value, DataValue::Null) {
1060                            match &max_value {
1061                                None => max_value = Some(value.clone()),
1062                                Some(current_max) => {
1063                                    if value > current_max {
1064                                        max_value = Some(value.clone());
1065                                    }
1066                                }
1067                            }
1068                        }
1069                    }
1070                }
1071
1072                Ok(max_value.unwrap_or(DataValue::Null))
1073            }
1074            "COUNT" => {
1075                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1076                // Use frame-aware count if frame is specified, otherwise use partition count
1077
1078                if args.is_empty() {
1079                    // COUNT(*) OVER (...)
1080                    if context.has_frame() {
1081                        Ok(context
1082                            .get_frame_count(row_index, None)
1083                            .unwrap_or(DataValue::Null))
1084                    } else {
1085                        Ok(context
1086                            .get_partition_count(row_index, None)
1087                            .unwrap_or(DataValue::Null))
1088                    }
1089                } else {
1090                    // Check for COUNT(*)
1091                    let column = match &args[0] {
1092                        SqlExpression::Column(col) => {
1093                            if col.name == "*" {
1094                                // COUNT(*) - count all rows
1095                                if context.has_frame() {
1096                                    return Ok(context
1097                                        .get_frame_count(row_index, None)
1098                                        .unwrap_or(DataValue::Null));
1099                                } else {
1100                                    return Ok(context
1101                                        .get_partition_count(row_index, None)
1102                                        .unwrap_or(DataValue::Null));
1103                                }
1104                            }
1105                            col.clone()
1106                        }
1107                        SqlExpression::StringLiteral(s) if s == "*" => {
1108                            // COUNT(*) as StringLiteral
1109                            if context.has_frame() {
1110                                return Ok(context
1111                                    .get_frame_count(row_index, None)
1112                                    .unwrap_or(DataValue::Null));
1113                            } else {
1114                                return Ok(context
1115                                    .get_partition_count(row_index, None)
1116                                    .unwrap_or(DataValue::Null));
1117                            }
1118                        }
1119                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1120                    };
1121
1122                    // COUNT(column) - count non-null values
1123                    if context.has_frame() {
1124                        Ok(context
1125                            .get_frame_count(row_index, Some(&column.name))
1126                            .unwrap_or(DataValue::Null))
1127                    } else {
1128                        Ok(context
1129                            .get_partition_count(row_index, Some(&column.name))
1130                            .unwrap_or(DataValue::Null))
1131                    }
1132                }
1133            }
1134            _ => Err(anyhow!("Unknown window function: {}", name)),
1135        }
1136    }
1137
1138    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1139    fn evaluate_method_call(
1140        &mut self,
1141        object: &str,
1142        method: &str,
1143        args: &[SqlExpression],
1144        row_index: usize,
1145    ) -> Result<DataValue> {
1146        // Get column value
1147        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1148            let suggestion = self.find_similar_column(object);
1149            match suggestion {
1150                Some(similar) => {
1151                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1152                }
1153                None => anyhow!("Column '{}' not found", object),
1154            }
1155        })?;
1156
1157        let cell_value = self.table.get_value(row_index, col_index).cloned();
1158
1159        self.evaluate_method_on_value(
1160            &cell_value.unwrap_or(DataValue::Null),
1161            method,
1162            args,
1163            row_index,
1164        )
1165    }
1166
1167    /// Evaluate a method on a value
1168    fn evaluate_method_on_value(
1169        &mut self,
1170        value: &DataValue,
1171        method: &str,
1172        args: &[SqlExpression],
1173        row_index: usize,
1174    ) -> Result<DataValue> {
1175        // First, try to proxy the method through the function registry
1176        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1177
1178        // Map method names to function names (case-insensitive matching)
1179        let function_name = match method.to_lowercase().as_str() {
1180            "trim" => "TRIM",
1181            "trimstart" | "trimbegin" => "TRIMSTART",
1182            "trimend" => "TRIMEND",
1183            "length" | "len" => "LENGTH",
1184            "contains" => "CONTAINS",
1185            "startswith" => "STARTSWITH",
1186            "endswith" => "ENDSWITH",
1187            "indexof" => "INDEXOF",
1188            _ => method, // Try the method name as-is
1189        };
1190
1191        // Check if we have this function in the registry
1192        if self.function_registry.get(function_name).is_some() {
1193            debug!(
1194                "Proxying method '{}' through function registry as '{}'",
1195                method, function_name
1196            );
1197
1198            // Prepare arguments: receiver is the first argument, followed by method args
1199            let mut func_args = vec![value.clone()];
1200
1201            // Evaluate method arguments and add them
1202            for arg in args {
1203                func_args.push(self.evaluate(arg, row_index)?);
1204            }
1205
1206            // Get the function and call it
1207            let func = self.function_registry.get(function_name).unwrap();
1208            return func.evaluate(&func_args);
1209        }
1210
1211        // If not in registry, the method is not supported
1212        // All methods should be registered in the function registry
1213        Err(anyhow!(
1214            "Method '{}' not found. It should be registered in the function registry.",
1215            method
1216        ))
1217    }
1218
1219    /// Evaluate a CASE expression
1220    fn evaluate_case_expression(
1221        &mut self,
1222        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1223        else_branch: &Option<Box<SqlExpression>>,
1224        row_index: usize,
1225    ) -> Result<DataValue> {
1226        debug!(
1227            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1228            row_index
1229        );
1230
1231        // Evaluate each WHEN condition in order
1232        for branch in when_branches {
1233            // Evaluate the condition as a boolean
1234            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1235
1236            if condition_result {
1237                debug!("CASE: WHEN condition matched, evaluating result expression");
1238                return self.evaluate(&branch.result, row_index);
1239            }
1240        }
1241
1242        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1243        if let Some(else_expr) = else_branch {
1244            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1245            self.evaluate(else_expr, row_index)
1246        } else {
1247            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1248            Ok(DataValue::Null)
1249        }
1250    }
1251
1252    /// Evaluate a simple CASE expression
1253    fn evaluate_simple_case_expression(
1254        &mut self,
1255        expr: &Box<SqlExpression>,
1256        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1257        else_branch: &Option<Box<SqlExpression>>,
1258        row_index: usize,
1259    ) -> Result<DataValue> {
1260        debug!(
1261            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1262            row_index
1263        );
1264
1265        // Evaluate the main expression once
1266        let case_value = self.evaluate(expr, row_index)?;
1267        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1268
1269        // Compare against each WHEN value in order
1270        for branch in when_branches {
1271            // Evaluate the WHEN value
1272            let when_value = self.evaluate(&branch.value, row_index)?;
1273
1274            // Check for equality
1275            if self.values_equal(&case_value, &when_value)? {
1276                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1277                return self.evaluate(&branch.result, row_index);
1278            }
1279        }
1280
1281        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1282        if let Some(else_expr) = else_branch {
1283            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1284            self.evaluate(else_expr, row_index)
1285        } else {
1286            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1287            Ok(DataValue::Null)
1288        }
1289    }
1290
1291    /// Check if two DataValues are equal
1292    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1293        match (left, right) {
1294            (DataValue::Null, DataValue::Null) => Ok(true),
1295            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1296            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1297            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1298            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1299            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1300            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1301            // Type coercion for numeric comparisons
1302            (DataValue::Integer(a), DataValue::Float(b)) => {
1303                Ok((*a as f64 - b).abs() < f64::EPSILON)
1304            }
1305            (DataValue::Float(a), DataValue::Integer(b)) => {
1306                Ok((a - *b as f64).abs() < f64::EPSILON)
1307            }
1308            _ => Ok(false),
1309        }
1310    }
1311
1312    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1313    fn evaluate_condition_as_bool(
1314        &mut self,
1315        expr: &SqlExpression,
1316        row_index: usize,
1317    ) -> Result<bool> {
1318        let value = self.evaluate(expr, row_index)?;
1319
1320        match value {
1321            DataValue::Boolean(b) => Ok(b),
1322            DataValue::Integer(i) => Ok(i != 0),
1323            DataValue::Float(f) => Ok(f != 0.0),
1324            DataValue::Null => Ok(false),
1325            DataValue::String(s) => Ok(!s.is_empty()),
1326            DataValue::InternedString(s) => Ok(!s.is_empty()),
1327            _ => Ok(true), // Other types are considered truthy
1328        }
1329    }
1330}
1331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// One-row table: a = Integer(10), b = Float(2.5), c = Integer(4).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is exactly representable, and unlike 3.14 it does not trip clippy's
        // deny-by-default `approx_constant` lint when tests are linted.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal_numeric_and_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer/Float coercion.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.5))
            .unwrap());
        // NULL handling inside values_equal itself.
        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());
        // Mismatched types never compare equal.
        assert!(!evaluator
            .values_equal(&DataValue::Integer(1), &DataValue::Boolean(true))
            .unwrap());
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let zero = SqlExpression::NumberLiteral("0".to_string());
        let column_a = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));

        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());
        assert!(evaluator.evaluate_condition_as_bool(&column_a, 0).unwrap());
    }
}