// File: sql_cli/data/arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
18/// This is different from `RecursiveWhereEvaluator` which returns boolean
19pub struct ArithmeticEvaluator<'a> {
20    table: &'a DataTable,
21    _date_notation: String,
22    function_registry: Arc<FunctionRegistry>,
23    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
24    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
25    window_function_registry: Arc<WindowFunctionRegistry>,
26    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
27    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
28    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
29}
30
31impl<'a> ArithmeticEvaluator<'a> {
32    #[must_use]
33    pub fn new(table: &'a DataTable) -> Self {
34        Self {
35            table,
36            _date_notation: get_date_notation(),
37            function_registry: Arc::new(FunctionRegistry::new()),
38            aggregate_registry: Arc::new(AggregateRegistry::new()),
39            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41            visible_rows: None,
42            window_contexts: HashMap::new(),
43            table_aliases: HashMap::new(),
44        }
45    }
46
47    #[must_use]
48    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49        Self {
50            table,
51            _date_notation: date_notation,
52            function_registry: Arc::new(FunctionRegistry::new()),
53            aggregate_registry: Arc::new(AggregateRegistry::new()),
54            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56            visible_rows: None,
57            window_contexts: HashMap::new(),
58            table_aliases: HashMap::new(),
59        }
60    }
61
62    /// Set visible rows for aggregate functions (for filtered views)
63    #[must_use]
64    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65        self.visible_rows = Some(rows);
66        self
67    }
68
69    /// Set table aliases for qualified column resolution
70    #[must_use]
71    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72        self.table_aliases = aliases;
73        self
74    }
75
76    #[must_use]
77    pub fn with_date_notation_and_registry(
78        table: &'a DataTable,
79        date_notation: String,
80        function_registry: Arc<FunctionRegistry>,
81    ) -> Self {
82        Self {
83            table,
84            _date_notation: date_notation,
85            function_registry,
86            aggregate_registry: Arc::new(AggregateRegistry::new()),
87            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89            visible_rows: None,
90            window_contexts: HashMap::new(),
91            table_aliases: HashMap::new(),
92        }
93    }
94
95    /// Find a column name similar to the given name using edit distance
96    fn find_similar_column(&self, name: &str) -> Option<String> {
97        let columns = self.table.column_names();
98        let mut best_match: Option<(String, usize)> = None;
99
100        for col in columns {
101            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102            // Only suggest if distance is small (likely a typo)
103            // Allow up to 3 edits for longer names
104            let max_distance = if name.len() > 10 { 3 } else { 2 };
105            if distance <= max_distance {
106                match &best_match {
107                    None => best_match = Some((col, distance)),
108                    Some((_, best_dist)) if distance < *best_dist => {
109                        best_match = Some((col, distance));
110                    }
111                    _ => {}
112                }
113            }
114        }
115
116        best_match.map(|(name, _)| name)
117    }
118
119    /// Calculate Levenshtein edit distance between two strings
120    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121        // Use the shared implementation from string_methods
122        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123    }
124
125    /// Evaluate an SQL expression to produce a `DataValue`
126    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127        debug!(
128            "ArithmeticEvaluator: evaluating {:?} for row {}",
129            expr, row_index
130        );
131
132        match expr {
133            SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137            SqlExpression::Null => Ok(DataValue::Null),
138            SqlExpression::BinaryOp { left, op, right } => {
139                self.evaluate_binary_op(left, op, right, row_index)
140            }
141            SqlExpression::FunctionCall {
142                name,
143                args,
144                distinct,
145            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146            SqlExpression::WindowFunction {
147                name,
148                args,
149                window_spec,
150            } => self.evaluate_window_function(name, args, window_spec, row_index),
151            SqlExpression::MethodCall {
152                object,
153                method,
154                args,
155            } => self.evaluate_method_call(object, method, args, row_index),
156            SqlExpression::ChainedMethodCall { base, method, args } => {
157                // Evaluate the base expression first, then apply the method
158                let base_value = self.evaluate(base, row_index)?;
159                self.evaluate_method_on_value(&base_value, method, args, row_index)
160            }
161            SqlExpression::CaseExpression {
162                when_branches,
163                else_branch,
164            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165            SqlExpression::SimpleCaseExpression {
166                expr,
167                when_branches,
168                else_branch,
169            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170            SqlExpression::DateTimeConstructor {
171                year,
172                month,
173                day,
174                hour,
175                minute,
176                second,
177            } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
178            SqlExpression::DateTimeToday {
179                hour,
180                minute,
181                second,
182            } => self.evaluate_datetime_today(*hour, *minute, *second),
183            _ => Err(anyhow!(
184                "Unsupported expression type for arithmetic evaluation: {:?}",
185                expr
186            )),
187        }
188    }
189
190    /// Evaluate a column reference with proper table scoping
191    fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
192        if let Some(table_prefix) = &column_ref.table_prefix {
193            // Resolve alias if it exists in table_aliases map
194            let actual_table = self
195                .table_aliases
196                .get(table_prefix)
197                .map(|s| s.as_str())
198                .unwrap_or(table_prefix);
199
200            // Try qualified lookup with resolved table name
201            let qualified_name = format!("{}.{}", actual_table, column_ref.name);
202
203            if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
204                debug!(
205                    "Resolved {}.{} -> '{}' at index {}",
206                    table_prefix, column_ref.name, qualified_name, col_idx
207                );
208                return self
209                    .table
210                    .get_value(row_index, col_idx)
211                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
212                    .map(|v| v.clone());
213            }
214
215            // Fallback: try unqualified lookup
216            if let Some(col_idx) = self.table.get_column_index(&column_ref.name) {
217                debug!(
218                    "Resolved {}.{} -> unqualified '{}' at index {}",
219                    table_prefix, column_ref.name, column_ref.name, col_idx
220                );
221                return self
222                    .table
223                    .get_value(row_index, col_idx)
224                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
225                    .map(|v| v.clone());
226            }
227
228            // If not found, return error
229            Err(anyhow!(
230                "Column '{}' not found. Table '{}' may not support qualified column names",
231                qualified_name,
232                actual_table
233            ))
234        } else {
235            // Simple column name lookup
236            self.evaluate_column(&column_ref.name, row_index)
237        }
238    }
239
240    /// Evaluate a column reference
241    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
242        // First try to resolve qualified column names (table.column or alias.column)
243        let resolved_column = if column_name.contains('.') {
244            // Split on last dot to handle cases like "schema.table.column"
245            if let Some(dot_pos) = column_name.rfind('.') {
246                let _table_or_alias = &column_name[..dot_pos];
247                let col_name = &column_name[dot_pos + 1..];
248
249                // For now, just use the column name part
250                // In the future, we could validate the table/alias part
251                debug!(
252                    "Resolving qualified column: {} -> {}",
253                    column_name, col_name
254                );
255                col_name.to_string()
256            } else {
257                column_name.to_string()
258            }
259        } else {
260            column_name.to_string()
261        };
262
263        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
264            idx
265        } else if resolved_column != column_name {
266            // If not found, try the original name
267            if let Some(idx) = self.table.get_column_index(column_name) {
268                idx
269            } else {
270                let suggestion = self.find_similar_column(&resolved_column);
271                return Err(match suggestion {
272                    Some(similar) => anyhow!(
273                        "Column '{}' not found. Did you mean '{}'?",
274                        column_name,
275                        similar
276                    ),
277                    None => anyhow!("Column '{}' not found", column_name),
278                });
279            }
280        } else {
281            let suggestion = self.find_similar_column(&resolved_column);
282            return Err(match suggestion {
283                Some(similar) => anyhow!(
284                    "Column '{}' not found. Did you mean '{}'?",
285                    column_name,
286                    similar
287                ),
288                None => anyhow!("Column '{}' not found", column_name),
289            });
290        };
291
292        if row_index >= self.table.row_count() {
293            return Err(anyhow!("Row index {} out of bounds", row_index));
294        }
295
296        let row = self
297            .table
298            .get_row(row_index)
299            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
300
301        let value = row
302            .get(col_index)
303            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
304
305        Ok(value.clone())
306    }
307
308    /// Evaluate a number literal (handles both integers and floats)
309    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
310        // Try to parse as integer first
311        if let Ok(int_val) = number_str.parse::<i64>() {
312            return Ok(DataValue::Integer(int_val));
313        }
314
315        // If that fails, try as float
316        if let Ok(float_val) = number_str.parse::<f64>() {
317            return Ok(DataValue::Float(float_val));
318        }
319
320        Err(anyhow!("Invalid number literal: {}", number_str))
321    }
322
323    /// Evaluate a binary operation (arithmetic)
324    fn evaluate_binary_op(
325        &mut self,
326        left: &SqlExpression,
327        op: &str,
328        right: &SqlExpression,
329        row_index: usize,
330    ) -> Result<DataValue> {
331        let left_val = self.evaluate(left, row_index)?;
332        let right_val = self.evaluate(right, row_index)?;
333
334        debug!(
335            "ArithmeticEvaluator: {} {} {}",
336            self.format_value(&left_val),
337            op,
338            self.format_value(&right_val)
339        );
340
341        match op {
342            "+" => self.add_values(&left_val, &right_val),
343            "-" => self.subtract_values(&left_val, &right_val),
344            "*" => self.multiply_values(&left_val, &right_val),
345            "/" => self.divide_values(&left_val, &right_val),
346            "%" => {
347                // Modulo operator - call MOD function
348                let args = vec![left.clone(), right.clone()];
349                self.evaluate_function("MOD", &args, row_index)
350            }
351            // Comparison operators (return boolean results)
352            // Use centralized comparison logic for consistency
353            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
354                let result = compare_with_op(&left_val, &right_val, op, false);
355                Ok(DataValue::Boolean(result))
356            }
357            // IS NULL / IS NOT NULL operators
358            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
359            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
360            // Logical operators
361            "AND" => {
362                let left_bool = self.to_bool(&left_val)?;
363                let right_bool = self.to_bool(&right_val)?;
364                Ok(DataValue::Boolean(left_bool && right_bool))
365            }
366            "OR" => {
367                let left_bool = self.to_bool(&left_val)?;
368                let right_bool = self.to_bool(&right_val)?;
369                Ok(DataValue::Boolean(left_bool || right_bool))
370            }
371            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
372        }
373    }
374
375    /// Add two `DataValues` with type coercion
376    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
377        // NULL handling - any operation with NULL returns NULL
378        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
379            return Ok(DataValue::Null);
380        }
381
382        match (left, right) {
383            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
384            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
385            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
386            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
387            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
388        }
389    }
390
391    /// Subtract two `DataValues` with type coercion
392    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
393        // NULL handling - any operation with NULL returns NULL
394        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
395            return Ok(DataValue::Null);
396        }
397
398        match (left, right) {
399            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
400            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
401            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
402            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
403            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
404        }
405    }
406
407    /// Multiply two `DataValues` with type coercion
408    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
409        // NULL handling - any operation with NULL returns NULL
410        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
411            return Ok(DataValue::Null);
412        }
413
414        match (left, right) {
415            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
416            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
417            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
418            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
419            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
420        }
421    }
422
423    /// Divide two `DataValues` with type coercion
424    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
425        // NULL handling - any operation with NULL returns NULL
426        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
427            return Ok(DataValue::Null);
428        }
429
430        // Check for division by zero first
431        let is_zero = match right {
432            DataValue::Integer(0) => true,
433            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
434            _ => false,
435        };
436
437        if is_zero {
438            return Err(anyhow!("Division by zero"));
439        }
440
441        match (left, right) {
442            (DataValue::Integer(a), DataValue::Integer(b)) => {
443                // Integer division - if result is exact, keep as int, otherwise promote to float
444                if a % b == 0 {
445                    Ok(DataValue::Integer(a / b))
446                } else {
447                    Ok(DataValue::Float(*a as f64 / *b as f64))
448                }
449            }
450            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
451            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
452            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
453            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
454        }
455    }
456
457    /// Format a `DataValue` for debug output
458    fn format_value(&self, value: &DataValue) -> String {
459        match value {
460            DataValue::Integer(i) => i.to_string(),
461            DataValue::Float(f) => f.to_string(),
462            DataValue::String(s) => format!("'{s}'"),
463            _ => format!("{value:?}"),
464        }
465    }
466
467    /// Convert a DataValue to boolean for logical operations
468    fn to_bool(&self, value: &DataValue) -> Result<bool> {
469        match value {
470            DataValue::Boolean(b) => Ok(*b),
471            DataValue::Integer(i) => Ok(*i != 0),
472            DataValue::Float(f) => Ok(*f != 0.0),
473            DataValue::Null => Ok(false),
474            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
475        }
476    }
477
478    /// Evaluate a function call
479    fn evaluate_function_with_distinct(
480        &mut self,
481        name: &str,
482        args: &[SqlExpression],
483        distinct: bool,
484        row_index: usize,
485    ) -> Result<DataValue> {
486        // If DISTINCT is specified, handle it specially for aggregate functions
487        if distinct {
488            let name_upper = name.to_uppercase();
489
490            // Check if it's an aggregate function in either registry
491            if self.aggregate_registry.is_aggregate(&name_upper)
492                || self.new_aggregate_registry.contains(&name_upper)
493            {
494                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
495            } else {
496                return Err(anyhow!(
497                    "DISTINCT can only be used with aggregate functions"
498                ));
499            }
500        }
501
502        // Otherwise, use the regular evaluation
503        self.evaluate_function(name, args, row_index)
504    }
505
506    fn evaluate_aggregate_with_distinct(
507        &mut self,
508        name: &str,
509        args: &[SqlExpression],
510        _row_index: usize,
511    ) -> Result<DataValue> {
512        let name_upper = name.to_uppercase();
513
514        // Check new aggregate registry first for migrated functions
515        if self.new_aggregate_registry.get(&name_upper).is_some() {
516            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
517                visible.clone()
518            } else {
519                (0..self.table.rows.len()).collect()
520            };
521
522            // Collect and deduplicate values for DISTINCT
523            let mut vals = Vec::new();
524            for &row_idx in &rows_to_process {
525                if !args.is_empty() {
526                    let value = self.evaluate(&args[0], row_idx)?;
527                    vals.push(value);
528                }
529            }
530
531            // Deduplicate values
532            let mut seen = HashSet::new();
533            let unique_values: Vec<_> = vals
534                .into_iter()
535                .filter(|v| {
536                    let key = format!("{:?}", v);
537                    seen.insert(key)
538                })
539                .collect();
540
541            // Get the aggregate function from the new registry
542            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
543            let mut state = agg_func.create_state();
544
545            // Use unique values
546            for value in &unique_values {
547                state.accumulate(value)?;
548            }
549
550            return Ok(state.finalize());
551        }
552
553        // Check old aggregate registry (DISTINCT handling)
554        if self.aggregate_registry.get(&name_upper).is_some() {
555            // Determine which rows to process first
556            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
557                visible.clone()
558            } else {
559                (0..self.table.rows.len()).collect()
560            };
561
562            // Special handling for STRING_AGG with separator parameter
563            if name_upper == "STRING_AGG" && args.len() >= 2 {
564                // STRING_AGG(DISTINCT column, separator)
565                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
566                    // Evaluate the separator (second argument) once
567                    if args.len() >= 2 {
568                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
569                        match separator {
570                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
571                            DataValue::InternedString(s) => {
572                                crate::sql::aggregates::StringAggState::new(&s)
573                            }
574                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
575                        }
576                    } else {
577                        crate::sql::aggregates::StringAggState::new(",")
578                    },
579                );
580
581                // Evaluate the first argument (column) for each row and accumulate
582                // Handle DISTINCT - use a HashSet to track seen values
583                let mut seen_values = HashSet::new();
584
585                for &row_idx in &rows_to_process {
586                    let value = self.evaluate(&args[0], row_idx)?;
587
588                    // Skip if we've seen this value
589                    if !seen_values.insert(value.clone()) {
590                        continue; // Skip duplicate values
591                    }
592
593                    // Now get the aggregate function and accumulate
594                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
595                    agg_func.accumulate(&mut state, &value)?;
596                }
597
598                // Finalize the aggregate
599                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
600                return Ok(agg_func.finalize(state));
601            }
602
603            // For other aggregates with DISTINCT
604            // Evaluate the argument expression for each row
605            let mut vals = Vec::new();
606            for &row_idx in &rows_to_process {
607                if !args.is_empty() {
608                    let value = self.evaluate(&args[0], row_idx)?;
609                    vals.push(value);
610                }
611            }
612
613            // Deduplicate values for DISTINCT
614            let mut seen = HashSet::new();
615            let mut unique_values = Vec::new();
616            for value in vals {
617                if seen.insert(value.clone()) {
618                    unique_values.push(value);
619                }
620            }
621
622            // Now get the aggregate function and process
623            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
624            let mut state = agg_func.init();
625
626            // Use unique values
627            for value in &unique_values {
628                agg_func.accumulate(&mut state, value)?;
629            }
630
631            return Ok(agg_func.finalize(state));
632        }
633
634        Err(anyhow!("Unknown aggregate function: {}", name))
635    }
636
637    fn evaluate_function(
638        &mut self,
639        name: &str,
640        args: &[SqlExpression],
641        row_index: usize,
642    ) -> Result<DataValue> {
643        // Check if this is an aggregate function
644        let name_upper = name.to_uppercase();
645
646        // Check new aggregate registry first (for migrated functions)
647        if self.new_aggregate_registry.get(&name_upper).is_some() {
648            // Use new registry for SUM
649            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
650                visible.clone()
651            } else {
652                (0..self.table.rows.len()).collect()
653            };
654
655            // Get the aggregate function from the new registry
656            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
657            let mut state = agg_func.create_state();
658
659            // Special handling for COUNT(*)
660            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
661                if args.is_empty()
662                    || (args.len() == 1
663                        && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
664                    || (args.len() == 1
665                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
666                {
667                    // COUNT(*) or COUNT_STAR - count all rows
668                    for _ in &rows_to_process {
669                        state.accumulate(&DataValue::Integer(1))?;
670                    }
671                } else {
672                    // COUNT(column) - count non-null values
673                    for &row_idx in &rows_to_process {
674                        let value = self.evaluate(&args[0], row_idx)?;
675                        state.accumulate(&value)?;
676                    }
677                }
678            } else {
679                // Other aggregates - evaluate arguments and accumulate
680                if !args.is_empty() {
681                    for &row_idx in &rows_to_process {
682                        let value = self.evaluate(&args[0], row_idx)?;
683                        state.accumulate(&value)?;
684                    }
685                }
686            }
687
688            return Ok(state.finalize());
689        }
690
691        // Check old aggregate registry (for non-migrated functions)
692        if self.aggregate_registry.get(&name_upper).is_some() {
693            // Determine which rows to process first
694            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
695                visible.clone()
696            } else {
697                (0..self.table.rows.len()).collect()
698            };
699
700            // Special handling for STRING_AGG with separator parameter
701            if name_upper == "STRING_AGG" && args.len() >= 2 {
702                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
703                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
704                    // Evaluate the separator (second argument) once
705                    if args.len() >= 2 {
706                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
707                        match separator {
708                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
709                            DataValue::InternedString(s) => {
710                                crate::sql::aggregates::StringAggState::new(&s)
711                            }
712                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
713                        }
714                    } else {
715                        crate::sql::aggregates::StringAggState::new(",")
716                    },
717                );
718
719                // Evaluate the first argument (column) for each row and accumulate
720                for &row_idx in &rows_to_process {
721                    let value = self.evaluate(&args[0], row_idx)?;
722                    // Now get the aggregate function and accumulate
723                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
724                    agg_func.accumulate(&mut state, &value)?;
725                }
726
727                // Finalize the aggregate
728                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
729                return Ok(agg_func.finalize(state));
730            }
731
732            // Evaluate arguments first if needed (to avoid borrow issues)
733            let values = if !args.is_empty()
734                && !(args.len() == 1
735                    && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
736            {
737                // Evaluate the argument expression for each row
738                let mut vals = Vec::new();
739                for &row_idx in &rows_to_process {
740                    let value = self.evaluate(&args[0], row_idx)?;
741                    vals.push(value);
742                }
743                Some(vals)
744            } else {
745                None
746            };
747
748            // Now get the aggregate function and process
749            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
750            let mut state = agg_func.init();
751
752            if let Some(values) = values {
753                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
754                for value in &values {
755                    agg_func.accumulate(&mut state, value)?;
756                }
757            } else {
758                // COUNT(*) case
759                for _ in &rows_to_process {
760                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
761                }
762            }
763
764            return Ok(agg_func.finalize(state));
765        }
766
767        // First check if this function exists in the registry
768        if self.function_registry.get(name).is_some() {
769            // Evaluate all arguments first to avoid borrow issues
770            let mut evaluated_args = Vec::new();
771            for arg in args {
772                evaluated_args.push(self.evaluate(arg, row_index)?);
773            }
774
775            // Get the function and call it
776            let func = self.function_registry.get(name).unwrap();
777            return func.evaluate(&evaluated_args);
778        }
779
780        // If not in registry, return error for unknown function
781        Err(anyhow!("Unknown function: {}", name))
782    }
783
784    /// Get or create a WindowContext for the given specification
785    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
786        // Create a key for caching based on the spec
787        let key = format!("{:?}", spec);
788
789        if let Some(context) = self.window_contexts.get(&key) {
790            return Ok(Arc::clone(context));
791        }
792
793        // Create a DataView from the table (with visible rows if filtered)
794        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
795            // Create a filtered view
796            let view = DataView::new(Arc::new(self.table.clone()));
797            // Apply filtering based on visible rows
798            // Note: This is a simplified approach - in production we'd need proper filtering
799            view
800        } else {
801            DataView::new(Arc::new(self.table.clone()))
802        };
803
804        // Create the WindowContext with the full spec (including frame)
805        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
806
807        let context = Arc::new(context);
808        self.window_contexts.insert(key, Arc::clone(&context));
809        Ok(context)
810    }
811
812    /// Evaluate a window function
813    fn evaluate_window_function(
814        &mut self,
815        name: &str,
816        args: &[SqlExpression],
817        spec: &WindowSpec,
818        row_index: usize,
819    ) -> Result<DataValue> {
820        let name_upper = name.to_uppercase();
821
822        // First check if this is a syntactic sugar function in the registry
823        debug!("Looking for window function {} in registry", name_upper);
824        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
825            debug!("Found window function {} in registry", name_upper);
826
827            // Dereference to get the actual window function
828            let window_fn = window_fn_arc.as_ref();
829
830            // Validate arguments
831            window_fn.validate_args(args)?;
832
833            // Transform the window spec based on the function's requirements
834            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
835
836            // Get or create the window context with the transformed spec
837            let context = self.get_or_create_window_context(&transformed_spec)?;
838
839            // Create an expression evaluator adapter
840            struct EvaluatorAdapter<'a, 'b> {
841                evaluator: &'a mut ArithmeticEvaluator<'b>,
842                row_index: usize,
843            }
844
845            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
846                fn evaluate(
847                    &mut self,
848                    expr: &SqlExpression,
849                    _row_index: usize,
850                ) -> Result<DataValue> {
851                    self.evaluator.evaluate(expr, self.row_index)
852                }
853            }
854
855            let mut adapter = EvaluatorAdapter {
856                evaluator: self,
857                row_index,
858            };
859
860            // Call the window function's compute method
861            return window_fn.compute(&context, row_index, args, &mut adapter);
862        }
863
864        // Fall back to built-in window functions
865        let context = self.get_or_create_window_context(spec)?;
866
867        match name_upper.as_str() {
868            "LAG" => {
869                // LAG(column, offset, default)
870                if args.is_empty() {
871                    return Err(anyhow!("LAG requires at least 1 argument"));
872                }
873
874                // Get column name
875                let column = match &args[0] {
876                    SqlExpression::Column(col) => col.clone(),
877                    _ => return Err(anyhow!("LAG first argument must be a column")),
878                };
879
880                // Get offset (default 1)
881                let offset = if args.len() > 1 {
882                    match self.evaluate(&args[1], row_index)? {
883                        DataValue::Integer(i) => i as i32,
884                        _ => return Err(anyhow!("LAG offset must be an integer")),
885                    }
886                } else {
887                    1
888                };
889
890                // Get value at offset
891                Ok(context
892                    .get_offset_value(row_index, -offset, &column.name)
893                    .unwrap_or(DataValue::Null))
894            }
895            "LEAD" => {
896                // LEAD(column, offset, default)
897                if args.is_empty() {
898                    return Err(anyhow!("LEAD requires at least 1 argument"));
899                }
900
901                // Get column name
902                let column = match &args[0] {
903                    SqlExpression::Column(col) => col.clone(),
904                    _ => return Err(anyhow!("LEAD first argument must be a column")),
905                };
906
907                // Get offset (default 1)
908                let offset = if args.len() > 1 {
909                    match self.evaluate(&args[1], row_index)? {
910                        DataValue::Integer(i) => i as i32,
911                        _ => return Err(anyhow!("LEAD offset must be an integer")),
912                    }
913                } else {
914                    1
915                };
916
917                // Get value at offset
918                Ok(context
919                    .get_offset_value(row_index, offset, &column.name)
920                    .unwrap_or(DataValue::Null))
921            }
922            "ROW_NUMBER" => {
923                // ROW_NUMBER() - no arguments
924                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
925            }
926            "FIRST_VALUE" => {
927                // FIRST_VALUE(column) OVER (... ROWS ...)
928                if args.is_empty() {
929                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
930                }
931
932                let column = match &args[0] {
933                    SqlExpression::Column(col) => col.clone(),
934                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
935                };
936
937                // Use frame-aware version if frame is specified
938                if context.has_frame() {
939                    Ok(context
940                        .get_frame_first_value(row_index, &column.name)
941                        .unwrap_or(DataValue::Null))
942                } else {
943                    Ok(context
944                        .get_first_value(row_index, &column.name)
945                        .unwrap_or(DataValue::Null))
946                }
947            }
948            "LAST_VALUE" => {
949                // LAST_VALUE(column) OVER (... ROWS ...)
950                if args.is_empty() {
951                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
952                }
953
954                let column = match &args[0] {
955                    SqlExpression::Column(col) => col.clone(),
956                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
957                };
958
959                // Use frame-aware version if frame is specified
960                if context.has_frame() {
961                    Ok(context
962                        .get_frame_last_value(row_index, &column.name)
963                        .unwrap_or(DataValue::Null))
964                } else {
965                    Ok(context
966                        .get_last_value(row_index, &column.name)
967                        .unwrap_or(DataValue::Null))
968                }
969            }
970            "SUM" => {
971                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
972                if args.is_empty() {
973                    return Err(anyhow!("SUM requires 1 argument"));
974                }
975
976                let column = match &args[0] {
977                    SqlExpression::Column(col) => col.clone(),
978                    _ => return Err(anyhow!("SUM argument must be a column")),
979                };
980
981                // Use frame-aware sum if frame is specified, otherwise use partition sum
982                if context.has_frame() {
983                    Ok(context
984                        .get_frame_sum(row_index, &column.name)
985                        .unwrap_or(DataValue::Null))
986                } else {
987                    Ok(context
988                        .get_partition_sum(row_index, &column.name)
989                        .unwrap_or(DataValue::Null))
990                }
991            }
992            "AVG" => {
993                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
994                if args.is_empty() {
995                    return Err(anyhow!("AVG requires 1 argument"));
996                }
997
998                let column = match &args[0] {
999                    SqlExpression::Column(col) => col.clone(),
1000                    _ => return Err(anyhow!("AVG argument must be a column")),
1001                };
1002
1003                Ok(context
1004                    .get_frame_avg(row_index, &column.name)
1005                    .unwrap_or(DataValue::Null))
1006            }
1007            "STDDEV" | "STDEV" => {
1008                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1009                if args.is_empty() {
1010                    return Err(anyhow!("STDDEV requires 1 argument"));
1011                }
1012
1013                let column = match &args[0] {
1014                    SqlExpression::Column(col) => col.clone(),
1015                    _ => return Err(anyhow!("STDDEV argument must be a column")),
1016                };
1017
1018                Ok(context
1019                    .get_frame_stddev(row_index, &column.name)
1020                    .unwrap_or(DataValue::Null))
1021            }
1022            "VARIANCE" | "VAR" => {
1023                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1024                if args.is_empty() {
1025                    return Err(anyhow!("VARIANCE requires 1 argument"));
1026                }
1027
1028                let column = match &args[0] {
1029                    SqlExpression::Column(col) => col.clone(),
1030                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
1031                };
1032
1033                Ok(context
1034                    .get_frame_variance(row_index, &column.name)
1035                    .unwrap_or(DataValue::Null))
1036            }
1037            "MIN" => {
1038                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1039                if args.is_empty() {
1040                    return Err(anyhow!("MIN requires 1 argument"));
1041                }
1042
1043                let column = match &args[0] {
1044                    SqlExpression::Column(col) => col.clone(),
1045                    _ => return Err(anyhow!("MIN argument must be a column")),
1046                };
1047
1048                let frame_rows = context.get_frame_rows(row_index);
1049                if frame_rows.is_empty() {
1050                    return Ok(DataValue::Null);
1051                }
1052
1053                let source_table = context.source();
1054                let col_idx = source_table
1055                    .get_column_index(&column.name)
1056                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1057
1058                let mut min_value: Option<DataValue> = None;
1059                for &row_idx in &frame_rows {
1060                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1061                        if !matches!(value, DataValue::Null) {
1062                            match &min_value {
1063                                None => min_value = Some(value.clone()),
1064                                Some(current_min) => {
1065                                    if value < current_min {
1066                                        min_value = Some(value.clone());
1067                                    }
1068                                }
1069                            }
1070                        }
1071                    }
1072                }
1073
1074                Ok(min_value.unwrap_or(DataValue::Null))
1075            }
1076            "MAX" => {
1077                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1078                if args.is_empty() {
1079                    return Err(anyhow!("MAX requires 1 argument"));
1080                }
1081
1082                let column = match &args[0] {
1083                    SqlExpression::Column(col) => col.clone(),
1084                    _ => return Err(anyhow!("MAX argument must be a column")),
1085                };
1086
1087                let frame_rows = context.get_frame_rows(row_index);
1088                if frame_rows.is_empty() {
1089                    return Ok(DataValue::Null);
1090                }
1091
1092                let source_table = context.source();
1093                let col_idx = source_table
1094                    .get_column_index(&column.name)
1095                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1096
1097                let mut max_value: Option<DataValue> = None;
1098                for &row_idx in &frame_rows {
1099                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1100                        if !matches!(value, DataValue::Null) {
1101                            match &max_value {
1102                                None => max_value = Some(value.clone()),
1103                                Some(current_max) => {
1104                                    if value > current_max {
1105                                        max_value = Some(value.clone());
1106                                    }
1107                                }
1108                            }
1109                        }
1110                    }
1111                }
1112
1113                Ok(max_value.unwrap_or(DataValue::Null))
1114            }
1115            "COUNT" => {
1116                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1117                // Use frame-aware count if frame is specified, otherwise use partition count
1118
1119                if args.is_empty() {
1120                    // COUNT(*) OVER (...)
1121                    if context.has_frame() {
1122                        Ok(context
1123                            .get_frame_count(row_index, None)
1124                            .unwrap_or(DataValue::Null))
1125                    } else {
1126                        Ok(context
1127                            .get_partition_count(row_index, None)
1128                            .unwrap_or(DataValue::Null))
1129                    }
1130                } else {
1131                    // Check for COUNT(*)
1132                    let column = match &args[0] {
1133                        SqlExpression::Column(col) => {
1134                            if col.name == "*" {
1135                                // COUNT(*) - count all rows
1136                                if context.has_frame() {
1137                                    return Ok(context
1138                                        .get_frame_count(row_index, None)
1139                                        .unwrap_or(DataValue::Null));
1140                                } else {
1141                                    return Ok(context
1142                                        .get_partition_count(row_index, None)
1143                                        .unwrap_or(DataValue::Null));
1144                                }
1145                            }
1146                            col.clone()
1147                        }
1148                        SqlExpression::StringLiteral(s) if s == "*" => {
1149                            // COUNT(*) as StringLiteral
1150                            if context.has_frame() {
1151                                return Ok(context
1152                                    .get_frame_count(row_index, None)
1153                                    .unwrap_or(DataValue::Null));
1154                            } else {
1155                                return Ok(context
1156                                    .get_partition_count(row_index, None)
1157                                    .unwrap_or(DataValue::Null));
1158                            }
1159                        }
1160                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1161                    };
1162
1163                    // COUNT(column) - count non-null values
1164                    if context.has_frame() {
1165                        Ok(context
1166                            .get_frame_count(row_index, Some(&column.name))
1167                            .unwrap_or(DataValue::Null))
1168                    } else {
1169                        Ok(context
1170                            .get_partition_count(row_index, Some(&column.name))
1171                            .unwrap_or(DataValue::Null))
1172                    }
1173                }
1174            }
1175            _ => Err(anyhow!("Unknown window function: {}", name)),
1176        }
1177    }
1178
1179    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1180    fn evaluate_method_call(
1181        &mut self,
1182        object: &str,
1183        method: &str,
1184        args: &[SqlExpression],
1185        row_index: usize,
1186    ) -> Result<DataValue> {
1187        // Get column value
1188        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1189            let suggestion = self.find_similar_column(object);
1190            match suggestion {
1191                Some(similar) => {
1192                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1193                }
1194                None => anyhow!("Column '{}' not found", object),
1195            }
1196        })?;
1197
1198        let cell_value = self.table.get_value(row_index, col_index).cloned();
1199
1200        self.evaluate_method_on_value(
1201            &cell_value.unwrap_or(DataValue::Null),
1202            method,
1203            args,
1204            row_index,
1205        )
1206    }
1207
1208    /// Evaluate a method on a value
1209    fn evaluate_method_on_value(
1210        &mut self,
1211        value: &DataValue,
1212        method: &str,
1213        args: &[SqlExpression],
1214        row_index: usize,
1215    ) -> Result<DataValue> {
1216        // First, try to proxy the method through the function registry
1217        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1218
1219        // Map method names to function names (case-insensitive matching)
1220        let function_name = match method.to_lowercase().as_str() {
1221            "trim" => "TRIM",
1222            "trimstart" | "trimbegin" => "TRIMSTART",
1223            "trimend" => "TRIMEND",
1224            "length" | "len" => "LENGTH",
1225            "contains" => "CONTAINS",
1226            "startswith" => "STARTSWITH",
1227            "endswith" => "ENDSWITH",
1228            "indexof" => "INDEXOF",
1229            _ => method, // Try the method name as-is
1230        };
1231
1232        // Check if we have this function in the registry
1233        if self.function_registry.get(function_name).is_some() {
1234            debug!(
1235                "Proxying method '{}' through function registry as '{}'",
1236                method, function_name
1237            );
1238
1239            // Prepare arguments: receiver is the first argument, followed by method args
1240            let mut func_args = vec![value.clone()];
1241
1242            // Evaluate method arguments and add them
1243            for arg in args {
1244                func_args.push(self.evaluate(arg, row_index)?);
1245            }
1246
1247            // Get the function and call it
1248            let func = self.function_registry.get(function_name).unwrap();
1249            return func.evaluate(&func_args);
1250        }
1251
1252        // If not in registry, the method is not supported
1253        // All methods should be registered in the function registry
1254        Err(anyhow!(
1255            "Method '{}' not found. It should be registered in the function registry.",
1256            method
1257        ))
1258    }
1259
1260    /// Evaluate a CASE expression
1261    fn evaluate_case_expression(
1262        &mut self,
1263        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1264        else_branch: &Option<Box<SqlExpression>>,
1265        row_index: usize,
1266    ) -> Result<DataValue> {
1267        debug!(
1268            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1269            row_index
1270        );
1271
1272        // Evaluate each WHEN condition in order
1273        for branch in when_branches {
1274            // Evaluate the condition as a boolean
1275            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1276
1277            if condition_result {
1278                debug!("CASE: WHEN condition matched, evaluating result expression");
1279                return self.evaluate(&branch.result, row_index);
1280            }
1281        }
1282
1283        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1284        if let Some(else_expr) = else_branch {
1285            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1286            self.evaluate(else_expr, row_index)
1287        } else {
1288            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1289            Ok(DataValue::Null)
1290        }
1291    }
1292
1293    /// Evaluate a simple CASE expression
1294    fn evaluate_simple_case_expression(
1295        &mut self,
1296        expr: &Box<SqlExpression>,
1297        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1298        else_branch: &Option<Box<SqlExpression>>,
1299        row_index: usize,
1300    ) -> Result<DataValue> {
1301        debug!(
1302            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1303            row_index
1304        );
1305
1306        // Evaluate the main expression once
1307        let case_value = self.evaluate(expr, row_index)?;
1308        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1309
1310        // Compare against each WHEN value in order
1311        for branch in when_branches {
1312            // Evaluate the WHEN value
1313            let when_value = self.evaluate(&branch.value, row_index)?;
1314
1315            // Check for equality
1316            if self.values_equal(&case_value, &when_value)? {
1317                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1318                return self.evaluate(&branch.result, row_index);
1319            }
1320        }
1321
1322        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1323        if let Some(else_expr) = else_branch {
1324            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1325            self.evaluate(else_expr, row_index)
1326        } else {
1327            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1328            Ok(DataValue::Null)
1329        }
1330    }
1331
1332    /// Check if two DataValues are equal
1333    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1334        match (left, right) {
1335            (DataValue::Null, DataValue::Null) => Ok(true),
1336            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1337            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1338            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1339            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1340            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1341            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1342            // Type coercion for numeric comparisons
1343            (DataValue::Integer(a), DataValue::Float(b)) => {
1344                Ok((*a as f64 - b).abs() < f64::EPSILON)
1345            }
1346            (DataValue::Float(a), DataValue::Integer(b)) => {
1347                Ok((a - *b as f64).abs() < f64::EPSILON)
1348            }
1349            _ => Ok(false),
1350        }
1351    }
1352
1353    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1354    fn evaluate_condition_as_bool(
1355        &mut self,
1356        expr: &SqlExpression,
1357        row_index: usize,
1358    ) -> Result<bool> {
1359        let value = self.evaluate(expr, row_index)?;
1360
1361        match value {
1362            DataValue::Boolean(b) => Ok(b),
1363            DataValue::Integer(i) => Ok(i != 0),
1364            DataValue::Float(f) => Ok(f != 0.0),
1365            DataValue::Null => Ok(false),
1366            DataValue::String(s) => Ok(!s.is_empty()),
1367            DataValue::InternedString(s) => Ok(!s.is_empty()),
1368            _ => Ok(true), // Other types are considered truthy
1369        }
1370    }
1371
1372    /// Evaluate a DATETIME constructor expression
1373    fn evaluate_datetime_constructor(
1374        &self,
1375        year: i32,
1376        month: u32,
1377        day: u32,
1378        hour: Option<u32>,
1379        minute: Option<u32>,
1380        second: Option<u32>,
1381    ) -> Result<DataValue> {
1382        use chrono::{NaiveDate, TimeZone, Utc};
1383
1384        // Create a NaiveDate
1385        let date = NaiveDate::from_ymd_opt(year, month, day)
1386            .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1387
1388        // Create datetime with provided time components or defaults
1389        let hour = hour.unwrap_or(0);
1390        let minute = minute.unwrap_or(0);
1391        let second = second.unwrap_or(0);
1392
1393        let naive_datetime = date
1394            .and_hms_opt(hour, minute, second)
1395            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1396
1397        // Convert to UTC DateTime
1398        let datetime = Utc.from_utc_datetime(&naive_datetime);
1399
1400        // Format as string with milliseconds
1401        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1402        Ok(DataValue::String(datetime_str))
1403    }
1404
1405    /// Evaluate a DATETIME.TODAY constructor expression
1406    fn evaluate_datetime_today(
1407        &self,
1408        hour: Option<u32>,
1409        minute: Option<u32>,
1410        second: Option<u32>,
1411    ) -> Result<DataValue> {
1412        use chrono::{TimeZone, Utc};
1413
1414        // Get today's date in UTC
1415        let today = Utc::now().date_naive();
1416
1417        // Create datetime with provided time components or defaults
1418        let hour = hour.unwrap_or(0);
1419        let minute = minute.unwrap_or(0);
1420        let second = second.unwrap_or(0);
1421
1422        let naive_datetime = today
1423            .and_hms_opt(hour, minute, second)
1424            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1425
1426        // Convert to UTC DateTime
1427        let datetime = Utc.from_utc_datetime(&naive_datetime);
1428
1429        // Format as string with milliseconds
1430        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1431        Ok(DataValue::String(datetime_str))
1432    }
1433}
1434
#[cfg(test)]
mod tests {
    //! Unit tests for `ArithmeticEvaluator`: column lookup, numeric literal
    //! parsing, and the add/multiply/divide value helpers.

    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table with columns `a = Integer(10)`,
    /// `b = Float(2.5)` and `c = Integer(4)`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// Shorthand for an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// A column reference evaluates to the stored value at the given row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.evaluate(&column("a"), 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    /// Number literals without a decimal point become `Integer`; literals with
    /// one become `Float`.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Use an exactly representable value that is not close to a std math
        // constant: `3.14` trips clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    /// Integer + Integer stays `Integer`; mixing in a `Float` yields `Float`.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    /// Integer * Float yields `Float`.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    /// Integer division stays `Integer` when exact and promotes to `Float`
    /// otherwise.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is an error mentioning "Division by zero".
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    /// A `BinaryOp` over two columns evaluates both sides and applies the
    /// operator: `a * b` with `a = 10`, `b = 2.5` gives `Float(25.0)`.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "*".to_string(),
            right: Box::new(column("b")),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}