sql_cli/data/arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
18/// This is different from `RecursiveWhereEvaluator` which returns boolean
19pub struct ArithmeticEvaluator<'a> {
20    table: &'a DataTable,
21    _date_notation: String,
22    function_registry: Arc<FunctionRegistry>,
23    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
24    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
25    window_function_registry: Arc<WindowFunctionRegistry>,
26    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
27    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
28    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
29}
30
31impl<'a> ArithmeticEvaluator<'a> {
32    #[must_use]
33    pub fn new(table: &'a DataTable) -> Self {
34        Self {
35            table,
36            _date_notation: get_date_notation(),
37            function_registry: Arc::new(FunctionRegistry::new()),
38            aggregate_registry: Arc::new(AggregateRegistry::new()),
39            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41            visible_rows: None,
42            window_contexts: HashMap::new(),
43            table_aliases: HashMap::new(),
44        }
45    }
46
47    #[must_use]
48    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49        Self {
50            table,
51            _date_notation: date_notation,
52            function_registry: Arc::new(FunctionRegistry::new()),
53            aggregate_registry: Arc::new(AggregateRegistry::new()),
54            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56            visible_rows: None,
57            window_contexts: HashMap::new(),
58            table_aliases: HashMap::new(),
59        }
60    }
61
62    /// Set visible rows for aggregate functions (for filtered views)
63    #[must_use]
64    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65        self.visible_rows = Some(rows);
66        self
67    }
68
69    /// Set table aliases for qualified column resolution
70    #[must_use]
71    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72        self.table_aliases = aliases;
73        self
74    }
75
76    #[must_use]
77    pub fn with_date_notation_and_registry(
78        table: &'a DataTable,
79        date_notation: String,
80        function_registry: Arc<FunctionRegistry>,
81    ) -> Self {
82        Self {
83            table,
84            _date_notation: date_notation,
85            function_registry,
86            aggregate_registry: Arc::new(AggregateRegistry::new()),
87            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89            visible_rows: None,
90            window_contexts: HashMap::new(),
91            table_aliases: HashMap::new(),
92        }
93    }
94
95    /// Find a column name similar to the given name using edit distance
96    fn find_similar_column(&self, name: &str) -> Option<String> {
97        let columns = self.table.column_names();
98        let mut best_match: Option<(String, usize)> = None;
99
100        for col in columns {
101            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102            // Only suggest if distance is small (likely a typo)
103            // Allow up to 3 edits for longer names
104            let max_distance = if name.len() > 10 { 3 } else { 2 };
105            if distance <= max_distance {
106                match &best_match {
107                    None => best_match = Some((col, distance)),
108                    Some((_, best_dist)) if distance < *best_dist => {
109                        best_match = Some((col, distance));
110                    }
111                    _ => {}
112                }
113            }
114        }
115
116        best_match.map(|(name, _)| name)
117    }
118
119    /// Calculate Levenshtein edit distance between two strings
120    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121        // Use the shared implementation from string_methods
122        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123    }
124
125    /// Evaluate an SQL expression to produce a `DataValue`
126    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127        debug!(
128            "ArithmeticEvaluator: evaluating {:?} for row {}",
129            expr, row_index
130        );
131
132        match expr {
133            SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137            SqlExpression::Null => Ok(DataValue::Null),
138            SqlExpression::BinaryOp { left, op, right } => {
139                self.evaluate_binary_op(left, op, right, row_index)
140            }
141            SqlExpression::FunctionCall {
142                name,
143                args,
144                distinct,
145            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146            SqlExpression::WindowFunction {
147                name,
148                args,
149                window_spec,
150            } => self.evaluate_window_function(name, args, window_spec, row_index),
151            SqlExpression::MethodCall {
152                object,
153                method,
154                args,
155            } => self.evaluate_method_call(object, method, args, row_index),
156            SqlExpression::ChainedMethodCall { base, method, args } => {
157                // Evaluate the base expression first, then apply the method
158                let base_value = self.evaluate(base, row_index)?;
159                self.evaluate_method_on_value(&base_value, method, args, row_index)
160            }
161            SqlExpression::CaseExpression {
162                when_branches,
163                else_branch,
164            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165            SqlExpression::SimpleCaseExpression {
166                expr,
167                when_branches,
168                else_branch,
169            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170            SqlExpression::DateTimeConstructor {
171                year,
172                month,
173                day,
174                hour,
175                minute,
176                second,
177            } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
178            SqlExpression::DateTimeToday {
179                hour,
180                minute,
181                second,
182            } => self.evaluate_datetime_today(*hour, *minute, *second),
183            _ => Err(anyhow!(
184                "Unsupported expression type for arithmetic evaluation: {:?}",
185                expr
186            )),
187        }
188    }
189
190    /// Evaluate a column reference with proper table scoping
191    fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
192        if let Some(table_prefix) = &column_ref.table_prefix {
193            // For qualified references, ONLY try qualified lookup - no fallback
194            let qualified_name = format!("{}.{}", table_prefix, column_ref.name);
195
196            if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
197                return self
198                    .table
199                    .get_value(row_index, col_idx)
200                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
201                    .map(|v| v.clone());
202            }
203
204            // If not found, return error - don't fall back
205            Err(anyhow!("Column '{}' not found", qualified_name))
206        } else {
207            // Simple column name lookup
208            self.evaluate_column(&column_ref.name, row_index)
209        }
210    }
211
212    /// Evaluate a column reference
213    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
214        // First try to resolve qualified column names (table.column or alias.column)
215        let resolved_column = if column_name.contains('.') {
216            // Split on last dot to handle cases like "schema.table.column"
217            if let Some(dot_pos) = column_name.rfind('.') {
218                let _table_or_alias = &column_name[..dot_pos];
219                let col_name = &column_name[dot_pos + 1..];
220
221                // For now, just use the column name part
222                // In the future, we could validate the table/alias part
223                debug!(
224                    "Resolving qualified column: {} -> {}",
225                    column_name, col_name
226                );
227                col_name.to_string()
228            } else {
229                column_name.to_string()
230            }
231        } else {
232            column_name.to_string()
233        };
234
235        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
236            idx
237        } else if resolved_column != column_name {
238            // If not found, try the original name
239            if let Some(idx) = self.table.get_column_index(column_name) {
240                idx
241            } else {
242                let suggestion = self.find_similar_column(&resolved_column);
243                return Err(match suggestion {
244                    Some(similar) => anyhow!(
245                        "Column '{}' not found. Did you mean '{}'?",
246                        column_name,
247                        similar
248                    ),
249                    None => anyhow!("Column '{}' not found", column_name),
250                });
251            }
252        } else {
253            let suggestion = self.find_similar_column(&resolved_column);
254            return Err(match suggestion {
255                Some(similar) => anyhow!(
256                    "Column '{}' not found. Did you mean '{}'?",
257                    column_name,
258                    similar
259                ),
260                None => anyhow!("Column '{}' not found", column_name),
261            });
262        };
263
264        if row_index >= self.table.row_count() {
265            return Err(anyhow!("Row index {} out of bounds", row_index));
266        }
267
268        let row = self
269            .table
270            .get_row(row_index)
271            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
272
273        let value = row
274            .get(col_index)
275            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
276
277        Ok(value.clone())
278    }
279
280    /// Evaluate a number literal (handles both integers and floats)
281    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
282        // Try to parse as integer first
283        if let Ok(int_val) = number_str.parse::<i64>() {
284            return Ok(DataValue::Integer(int_val));
285        }
286
287        // If that fails, try as float
288        if let Ok(float_val) = number_str.parse::<f64>() {
289            return Ok(DataValue::Float(float_val));
290        }
291
292        Err(anyhow!("Invalid number literal: {}", number_str))
293    }
294
295    /// Evaluate a binary operation (arithmetic)
296    fn evaluate_binary_op(
297        &mut self,
298        left: &SqlExpression,
299        op: &str,
300        right: &SqlExpression,
301        row_index: usize,
302    ) -> Result<DataValue> {
303        let left_val = self.evaluate(left, row_index)?;
304        let right_val = self.evaluate(right, row_index)?;
305
306        debug!(
307            "ArithmeticEvaluator: {} {} {}",
308            self.format_value(&left_val),
309            op,
310            self.format_value(&right_val)
311        );
312
313        match op {
314            "+" => self.add_values(&left_val, &right_val),
315            "-" => self.subtract_values(&left_val, &right_val),
316            "*" => self.multiply_values(&left_val, &right_val),
317            "/" => self.divide_values(&left_val, &right_val),
318            "%" => {
319                // Modulo operator - call MOD function
320                let args = vec![left.clone(), right.clone()];
321                self.evaluate_function("MOD", &args, row_index)
322            }
323            // Comparison operators (return boolean results)
324            // Use centralized comparison logic for consistency
325            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
326                let result = compare_with_op(&left_val, &right_val, op, false);
327                Ok(DataValue::Boolean(result))
328            }
329            // IS NULL / IS NOT NULL operators
330            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
331            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
332            // Logical operators
333            "AND" => {
334                let left_bool = self.to_bool(&left_val)?;
335                let right_bool = self.to_bool(&right_val)?;
336                Ok(DataValue::Boolean(left_bool && right_bool))
337            }
338            "OR" => {
339                let left_bool = self.to_bool(&left_val)?;
340                let right_bool = self.to_bool(&right_val)?;
341                Ok(DataValue::Boolean(left_bool || right_bool))
342            }
343            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
344        }
345    }
346
347    /// Add two `DataValues` with type coercion
348    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
349        // NULL handling - any operation with NULL returns NULL
350        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
351            return Ok(DataValue::Null);
352        }
353
354        match (left, right) {
355            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
356            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
357            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
358            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
359            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
360        }
361    }
362
363    /// Subtract two `DataValues` with type coercion
364    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
365        // NULL handling - any operation with NULL returns NULL
366        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
367            return Ok(DataValue::Null);
368        }
369
370        match (left, right) {
371            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
372            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
373            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
374            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
375            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
376        }
377    }
378
379    /// Multiply two `DataValues` with type coercion
380    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
381        // NULL handling - any operation with NULL returns NULL
382        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
383            return Ok(DataValue::Null);
384        }
385
386        match (left, right) {
387            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
388            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
389            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
390            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
391            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
392        }
393    }
394
395    /// Divide two `DataValues` with type coercion
396    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
397        // NULL handling - any operation with NULL returns NULL
398        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
399            return Ok(DataValue::Null);
400        }
401
402        // Check for division by zero first
403        let is_zero = match right {
404            DataValue::Integer(0) => true,
405            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
406            _ => false,
407        };
408
409        if is_zero {
410            return Err(anyhow!("Division by zero"));
411        }
412
413        match (left, right) {
414            (DataValue::Integer(a), DataValue::Integer(b)) => {
415                // Integer division - if result is exact, keep as int, otherwise promote to float
416                if a % b == 0 {
417                    Ok(DataValue::Integer(a / b))
418                } else {
419                    Ok(DataValue::Float(*a as f64 / *b as f64))
420                }
421            }
422            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
423            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
424            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
425            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
426        }
427    }
428
429    /// Format a `DataValue` for debug output
430    fn format_value(&self, value: &DataValue) -> String {
431        match value {
432            DataValue::Integer(i) => i.to_string(),
433            DataValue::Float(f) => f.to_string(),
434            DataValue::String(s) => format!("'{s}'"),
435            _ => format!("{value:?}"),
436        }
437    }
438
439    /// Convert a DataValue to boolean for logical operations
440    fn to_bool(&self, value: &DataValue) -> Result<bool> {
441        match value {
442            DataValue::Boolean(b) => Ok(*b),
443            DataValue::Integer(i) => Ok(*i != 0),
444            DataValue::Float(f) => Ok(*f != 0.0),
445            DataValue::Null => Ok(false),
446            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
447        }
448    }
449
450    /// Evaluate a function call
451    fn evaluate_function_with_distinct(
452        &mut self,
453        name: &str,
454        args: &[SqlExpression],
455        distinct: bool,
456        row_index: usize,
457    ) -> Result<DataValue> {
458        // If DISTINCT is specified, handle it specially for aggregate functions
459        if distinct {
460            let name_upper = name.to_uppercase();
461
462            // Check if it's an aggregate function in either registry
463            if self.aggregate_registry.is_aggregate(&name_upper)
464                || self.new_aggregate_registry.contains(&name_upper)
465            {
466                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
467            } else {
468                return Err(anyhow!(
469                    "DISTINCT can only be used with aggregate functions"
470                ));
471            }
472        }
473
474        // Otherwise, use the regular evaluation
475        self.evaluate_function(name, args, row_index)
476    }
477
478    fn evaluate_aggregate_with_distinct(
479        &mut self,
480        name: &str,
481        args: &[SqlExpression],
482        _row_index: usize,
483    ) -> Result<DataValue> {
484        let name_upper = name.to_uppercase();
485
486        // Check new aggregate registry first for migrated functions
487        if self.new_aggregate_registry.get(&name_upper).is_some() {
488            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
489                visible.clone()
490            } else {
491                (0..self.table.rows.len()).collect()
492            };
493
494            // Collect and deduplicate values for DISTINCT
495            let mut vals = Vec::new();
496            for &row_idx in &rows_to_process {
497                if !args.is_empty() {
498                    let value = self.evaluate(&args[0], row_idx)?;
499                    vals.push(value);
500                }
501            }
502
503            // Deduplicate values
504            let mut seen = HashSet::new();
505            let unique_values: Vec<_> = vals
506                .into_iter()
507                .filter(|v| {
508                    let key = format!("{:?}", v);
509                    seen.insert(key)
510                })
511                .collect();
512
513            // Get the aggregate function from the new registry
514            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
515            let mut state = agg_func.create_state();
516
517            // Use unique values
518            for value in &unique_values {
519                state.accumulate(value)?;
520            }
521
522            return Ok(state.finalize());
523        }
524
525        // Check old aggregate registry (DISTINCT handling)
526        if self.aggregate_registry.get(&name_upper).is_some() {
527            // Determine which rows to process first
528            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
529                visible.clone()
530            } else {
531                (0..self.table.rows.len()).collect()
532            };
533
534            // Special handling for STRING_AGG with separator parameter
535            if name_upper == "STRING_AGG" && args.len() >= 2 {
536                // STRING_AGG(DISTINCT column, separator)
537                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
538                    // Evaluate the separator (second argument) once
539                    if args.len() >= 2 {
540                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
541                        match separator {
542                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
543                            DataValue::InternedString(s) => {
544                                crate::sql::aggregates::StringAggState::new(&s)
545                            }
546                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
547                        }
548                    } else {
549                        crate::sql::aggregates::StringAggState::new(",")
550                    },
551                );
552
553                // Evaluate the first argument (column) for each row and accumulate
554                // Handle DISTINCT - use a HashSet to track seen values
555                let mut seen_values = HashSet::new();
556
557                for &row_idx in &rows_to_process {
558                    let value = self.evaluate(&args[0], row_idx)?;
559
560                    // Skip if we've seen this value
561                    if !seen_values.insert(value.clone()) {
562                        continue; // Skip duplicate values
563                    }
564
565                    // Now get the aggregate function and accumulate
566                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
567                    agg_func.accumulate(&mut state, &value)?;
568                }
569
570                // Finalize the aggregate
571                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
572                return Ok(agg_func.finalize(state));
573            }
574
575            // For other aggregates with DISTINCT
576            // Evaluate the argument expression for each row
577            let mut vals = Vec::new();
578            for &row_idx in &rows_to_process {
579                if !args.is_empty() {
580                    let value = self.evaluate(&args[0], row_idx)?;
581                    vals.push(value);
582                }
583            }
584
585            // Deduplicate values for DISTINCT
586            let mut seen = HashSet::new();
587            let mut unique_values = Vec::new();
588            for value in vals {
589                if seen.insert(value.clone()) {
590                    unique_values.push(value);
591                }
592            }
593
594            // Now get the aggregate function and process
595            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
596            let mut state = agg_func.init();
597
598            // Use unique values
599            for value in &unique_values {
600                agg_func.accumulate(&mut state, value)?;
601            }
602
603            return Ok(agg_func.finalize(state));
604        }
605
606        Err(anyhow!("Unknown aggregate function: {}", name))
607    }
608
609    fn evaluate_function(
610        &mut self,
611        name: &str,
612        args: &[SqlExpression],
613        row_index: usize,
614    ) -> Result<DataValue> {
615        // Check if this is an aggregate function
616        let name_upper = name.to_uppercase();
617
618        // Check new aggregate registry first (for migrated functions)
619        if self.new_aggregate_registry.get(&name_upper).is_some() {
620            // Use new registry for SUM
621            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
622                visible.clone()
623            } else {
624                (0..self.table.rows.len()).collect()
625            };
626
627            // Get the aggregate function from the new registry
628            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
629            let mut state = agg_func.create_state();
630
631            // Special handling for COUNT(*)
632            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
633                if args.is_empty()
634                    || (args.len() == 1
635                        && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
636                    || (args.len() == 1
637                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
638                {
639                    // COUNT(*) or COUNT_STAR - count all rows
640                    for _ in &rows_to_process {
641                        state.accumulate(&DataValue::Integer(1))?;
642                    }
643                } else {
644                    // COUNT(column) - count non-null values
645                    for &row_idx in &rows_to_process {
646                        let value = self.evaluate(&args[0], row_idx)?;
647                        state.accumulate(&value)?;
648                    }
649                }
650            } else {
651                // Other aggregates - evaluate arguments and accumulate
652                if !args.is_empty() {
653                    for &row_idx in &rows_to_process {
654                        let value = self.evaluate(&args[0], row_idx)?;
655                        state.accumulate(&value)?;
656                    }
657                }
658            }
659
660            return Ok(state.finalize());
661        }
662
663        // Check old aggregate registry (for non-migrated functions)
664        if self.aggregate_registry.get(&name_upper).is_some() {
665            // Determine which rows to process first
666            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
667                visible.clone()
668            } else {
669                (0..self.table.rows.len()).collect()
670            };
671
672            // Special handling for STRING_AGG with separator parameter
673            if name_upper == "STRING_AGG" && args.len() >= 2 {
674                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
675                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
676                    // Evaluate the separator (second argument) once
677                    if args.len() >= 2 {
678                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
679                        match separator {
680                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
681                            DataValue::InternedString(s) => {
682                                crate::sql::aggregates::StringAggState::new(&s)
683                            }
684                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
685                        }
686                    } else {
687                        crate::sql::aggregates::StringAggState::new(",")
688                    },
689                );
690
691                // Evaluate the first argument (column) for each row and accumulate
692                for &row_idx in &rows_to_process {
693                    let value = self.evaluate(&args[0], row_idx)?;
694                    // Now get the aggregate function and accumulate
695                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
696                    agg_func.accumulate(&mut state, &value)?;
697                }
698
699                // Finalize the aggregate
700                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
701                return Ok(agg_func.finalize(state));
702            }
703
704            // Evaluate arguments first if needed (to avoid borrow issues)
705            let values = if !args.is_empty()
706                && !(args.len() == 1
707                    && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
708            {
709                // Evaluate the argument expression for each row
710                let mut vals = Vec::new();
711                for &row_idx in &rows_to_process {
712                    let value = self.evaluate(&args[0], row_idx)?;
713                    vals.push(value);
714                }
715                Some(vals)
716            } else {
717                None
718            };
719
720            // Now get the aggregate function and process
721            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
722            let mut state = agg_func.init();
723
724            if let Some(values) = values {
725                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
726                for value in &values {
727                    agg_func.accumulate(&mut state, value)?;
728                }
729            } else {
730                // COUNT(*) case
731                for _ in &rows_to_process {
732                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
733                }
734            }
735
736            return Ok(agg_func.finalize(state));
737        }
738
739        // First check if this function exists in the registry
740        if self.function_registry.get(name).is_some() {
741            // Evaluate all arguments first to avoid borrow issues
742            let mut evaluated_args = Vec::new();
743            for arg in args {
744                evaluated_args.push(self.evaluate(arg, row_index)?);
745            }
746
747            // Get the function and call it
748            let func = self.function_registry.get(name).unwrap();
749            return func.evaluate(&evaluated_args);
750        }
751
752        // If not in registry, return error for unknown function
753        Err(anyhow!("Unknown function: {}", name))
754    }
755
756    /// Get or create a WindowContext for the given specification
757    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
758        // Create a key for caching based on the spec
759        let key = format!("{:?}", spec);
760
761        if let Some(context) = self.window_contexts.get(&key) {
762            return Ok(Arc::clone(context));
763        }
764
765        // Create a DataView from the table (with visible rows if filtered)
766        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
767            // Create a filtered view
768            let view = DataView::new(Arc::new(self.table.clone()));
769            // Apply filtering based on visible rows
770            // Note: This is a simplified approach - in production we'd need proper filtering
771            view
772        } else {
773            DataView::new(Arc::new(self.table.clone()))
774        };
775
776        // Create the WindowContext with the full spec (including frame)
777        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
778
779        let context = Arc::new(context);
780        self.window_contexts.insert(key, Arc::clone(&context));
781        Ok(context)
782    }
783
784    /// Evaluate a window function
785    fn evaluate_window_function(
786        &mut self,
787        name: &str,
788        args: &[SqlExpression],
789        spec: &WindowSpec,
790        row_index: usize,
791    ) -> Result<DataValue> {
792        let name_upper = name.to_uppercase();
793
794        // First check if this is a syntactic sugar function in the registry
795        debug!("Looking for window function {} in registry", name_upper);
796        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
797            debug!("Found window function {} in registry", name_upper);
798
799            // Dereference to get the actual window function
800            let window_fn = window_fn_arc.as_ref();
801
802            // Validate arguments
803            window_fn.validate_args(args)?;
804
805            // Transform the window spec based on the function's requirements
806            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
807
808            // Get or create the window context with the transformed spec
809            let context = self.get_or_create_window_context(&transformed_spec)?;
810
811            // Create an expression evaluator adapter
812            struct EvaluatorAdapter<'a, 'b> {
813                evaluator: &'a mut ArithmeticEvaluator<'b>,
814                row_index: usize,
815            }
816
817            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
818                fn evaluate(
819                    &mut self,
820                    expr: &SqlExpression,
821                    _row_index: usize,
822                ) -> Result<DataValue> {
823                    self.evaluator.evaluate(expr, self.row_index)
824                }
825            }
826
827            let mut adapter = EvaluatorAdapter {
828                evaluator: self,
829                row_index,
830            };
831
832            // Call the window function's compute method
833            return window_fn.compute(&context, row_index, args, &mut adapter);
834        }
835
836        // Fall back to built-in window functions
837        let context = self.get_or_create_window_context(spec)?;
838
839        match name_upper.as_str() {
840            "LAG" => {
841                // LAG(column, offset, default)
842                if args.is_empty() {
843                    return Err(anyhow!("LAG requires at least 1 argument"));
844                }
845
846                // Get column name
847                let column = match &args[0] {
848                    SqlExpression::Column(col) => col.clone(),
849                    _ => return Err(anyhow!("LAG first argument must be a column")),
850                };
851
852                // Get offset (default 1)
853                let offset = if args.len() > 1 {
854                    match self.evaluate(&args[1], row_index)? {
855                        DataValue::Integer(i) => i as i32,
856                        _ => return Err(anyhow!("LAG offset must be an integer")),
857                    }
858                } else {
859                    1
860                };
861
862                // Get value at offset
863                Ok(context
864                    .get_offset_value(row_index, -offset, &column.name)
865                    .unwrap_or(DataValue::Null))
866            }
867            "LEAD" => {
868                // LEAD(column, offset, default)
869                if args.is_empty() {
870                    return Err(anyhow!("LEAD requires at least 1 argument"));
871                }
872
873                // Get column name
874                let column = match &args[0] {
875                    SqlExpression::Column(col) => col.clone(),
876                    _ => return Err(anyhow!("LEAD first argument must be a column")),
877                };
878
879                // Get offset (default 1)
880                let offset = if args.len() > 1 {
881                    match self.evaluate(&args[1], row_index)? {
882                        DataValue::Integer(i) => i as i32,
883                        _ => return Err(anyhow!("LEAD offset must be an integer")),
884                    }
885                } else {
886                    1
887                };
888
889                // Get value at offset
890                Ok(context
891                    .get_offset_value(row_index, offset, &column.name)
892                    .unwrap_or(DataValue::Null))
893            }
894            "ROW_NUMBER" => {
895                // ROW_NUMBER() - no arguments
896                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
897            }
898            "FIRST_VALUE" => {
899                // FIRST_VALUE(column) OVER (... ROWS ...)
900                if args.is_empty() {
901                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
902                }
903
904                let column = match &args[0] {
905                    SqlExpression::Column(col) => col.clone(),
906                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
907                };
908
909                // Use frame-aware version if frame is specified
910                if context.has_frame() {
911                    Ok(context
912                        .get_frame_first_value(row_index, &column.name)
913                        .unwrap_or(DataValue::Null))
914                } else {
915                    Ok(context
916                        .get_first_value(row_index, &column.name)
917                        .unwrap_or(DataValue::Null))
918                }
919            }
920            "LAST_VALUE" => {
921                // LAST_VALUE(column) OVER (... ROWS ...)
922                if args.is_empty() {
923                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
924                }
925
926                let column = match &args[0] {
927                    SqlExpression::Column(col) => col.clone(),
928                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
929                };
930
931                // Use frame-aware version if frame is specified
932                if context.has_frame() {
933                    Ok(context
934                        .get_frame_last_value(row_index, &column.name)
935                        .unwrap_or(DataValue::Null))
936                } else {
937                    Ok(context
938                        .get_last_value(row_index, &column.name)
939                        .unwrap_or(DataValue::Null))
940                }
941            }
942            "SUM" => {
943                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
944                if args.is_empty() {
945                    return Err(anyhow!("SUM requires 1 argument"));
946                }
947
948                let column = match &args[0] {
949                    SqlExpression::Column(col) => col.clone(),
950                    _ => return Err(anyhow!("SUM argument must be a column")),
951                };
952
953                // Use frame-aware sum if frame is specified, otherwise use partition sum
954                if context.has_frame() {
955                    Ok(context
956                        .get_frame_sum(row_index, &column.name)
957                        .unwrap_or(DataValue::Null))
958                } else {
959                    Ok(context
960                        .get_partition_sum(row_index, &column.name)
961                        .unwrap_or(DataValue::Null))
962                }
963            }
964            "AVG" => {
965                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
966                if args.is_empty() {
967                    return Err(anyhow!("AVG requires 1 argument"));
968                }
969
970                let column = match &args[0] {
971                    SqlExpression::Column(col) => col.clone(),
972                    _ => return Err(anyhow!("AVG argument must be a column")),
973                };
974
975                Ok(context
976                    .get_frame_avg(row_index, &column.name)
977                    .unwrap_or(DataValue::Null))
978            }
979            "STDDEV" | "STDEV" => {
980                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
981                if args.is_empty() {
982                    return Err(anyhow!("STDDEV requires 1 argument"));
983                }
984
985                let column = match &args[0] {
986                    SqlExpression::Column(col) => col.clone(),
987                    _ => return Err(anyhow!("STDDEV argument must be a column")),
988                };
989
990                Ok(context
991                    .get_frame_stddev(row_index, &column.name)
992                    .unwrap_or(DataValue::Null))
993            }
994            "VARIANCE" | "VAR" => {
995                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
996                if args.is_empty() {
997                    return Err(anyhow!("VARIANCE requires 1 argument"));
998                }
999
1000                let column = match &args[0] {
1001                    SqlExpression::Column(col) => col.clone(),
1002                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
1003                };
1004
1005                Ok(context
1006                    .get_frame_variance(row_index, &column.name)
1007                    .unwrap_or(DataValue::Null))
1008            }
1009            "MIN" => {
1010                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1011                if args.is_empty() {
1012                    return Err(anyhow!("MIN requires 1 argument"));
1013                }
1014
1015                let column = match &args[0] {
1016                    SqlExpression::Column(col) => col.clone(),
1017                    _ => return Err(anyhow!("MIN argument must be a column")),
1018                };
1019
1020                let frame_rows = context.get_frame_rows(row_index);
1021                if frame_rows.is_empty() {
1022                    return Ok(DataValue::Null);
1023                }
1024
1025                let source_table = context.source();
1026                let col_idx = source_table
1027                    .get_column_index(&column.name)
1028                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1029
1030                let mut min_value: Option<DataValue> = None;
1031                for &row_idx in &frame_rows {
1032                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1033                        if !matches!(value, DataValue::Null) {
1034                            match &min_value {
1035                                None => min_value = Some(value.clone()),
1036                                Some(current_min) => {
1037                                    if value < current_min {
1038                                        min_value = Some(value.clone());
1039                                    }
1040                                }
1041                            }
1042                        }
1043                    }
1044                }
1045
1046                Ok(min_value.unwrap_or(DataValue::Null))
1047            }
1048            "MAX" => {
1049                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1050                if args.is_empty() {
1051                    return Err(anyhow!("MAX requires 1 argument"));
1052                }
1053
1054                let column = match &args[0] {
1055                    SqlExpression::Column(col) => col.clone(),
1056                    _ => return Err(anyhow!("MAX argument must be a column")),
1057                };
1058
1059                let frame_rows = context.get_frame_rows(row_index);
1060                if frame_rows.is_empty() {
1061                    return Ok(DataValue::Null);
1062                }
1063
1064                let source_table = context.source();
1065                let col_idx = source_table
1066                    .get_column_index(&column.name)
1067                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1068
1069                let mut max_value: Option<DataValue> = None;
1070                for &row_idx in &frame_rows {
1071                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1072                        if !matches!(value, DataValue::Null) {
1073                            match &max_value {
1074                                None => max_value = Some(value.clone()),
1075                                Some(current_max) => {
1076                                    if value > current_max {
1077                                        max_value = Some(value.clone());
1078                                    }
1079                                }
1080                            }
1081                        }
1082                    }
1083                }
1084
1085                Ok(max_value.unwrap_or(DataValue::Null))
1086            }
1087            "COUNT" => {
1088                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1089                // Use frame-aware count if frame is specified, otherwise use partition count
1090
1091                if args.is_empty() {
1092                    // COUNT(*) OVER (...)
1093                    if context.has_frame() {
1094                        Ok(context
1095                            .get_frame_count(row_index, None)
1096                            .unwrap_or(DataValue::Null))
1097                    } else {
1098                        Ok(context
1099                            .get_partition_count(row_index, None)
1100                            .unwrap_or(DataValue::Null))
1101                    }
1102                } else {
1103                    // Check for COUNT(*)
1104                    let column = match &args[0] {
1105                        SqlExpression::Column(col) => {
1106                            if col.name == "*" {
1107                                // COUNT(*) - count all rows
1108                                if context.has_frame() {
1109                                    return Ok(context
1110                                        .get_frame_count(row_index, None)
1111                                        .unwrap_or(DataValue::Null));
1112                                } else {
1113                                    return Ok(context
1114                                        .get_partition_count(row_index, None)
1115                                        .unwrap_or(DataValue::Null));
1116                                }
1117                            }
1118                            col.clone()
1119                        }
1120                        SqlExpression::StringLiteral(s) if s == "*" => {
1121                            // COUNT(*) as StringLiteral
1122                            if context.has_frame() {
1123                                return Ok(context
1124                                    .get_frame_count(row_index, None)
1125                                    .unwrap_or(DataValue::Null));
1126                            } else {
1127                                return Ok(context
1128                                    .get_partition_count(row_index, None)
1129                                    .unwrap_or(DataValue::Null));
1130                            }
1131                        }
1132                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1133                    };
1134
1135                    // COUNT(column) - count non-null values
1136                    if context.has_frame() {
1137                        Ok(context
1138                            .get_frame_count(row_index, Some(&column.name))
1139                            .unwrap_or(DataValue::Null))
1140                    } else {
1141                        Ok(context
1142                            .get_partition_count(row_index, Some(&column.name))
1143                            .unwrap_or(DataValue::Null))
1144                    }
1145                }
1146            }
1147            _ => Err(anyhow!("Unknown window function: {}", name)),
1148        }
1149    }
1150
1151    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1152    fn evaluate_method_call(
1153        &mut self,
1154        object: &str,
1155        method: &str,
1156        args: &[SqlExpression],
1157        row_index: usize,
1158    ) -> Result<DataValue> {
1159        // Get column value
1160        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1161            let suggestion = self.find_similar_column(object);
1162            match suggestion {
1163                Some(similar) => {
1164                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1165                }
1166                None => anyhow!("Column '{}' not found", object),
1167            }
1168        })?;
1169
1170        let cell_value = self.table.get_value(row_index, col_index).cloned();
1171
1172        self.evaluate_method_on_value(
1173            &cell_value.unwrap_or(DataValue::Null),
1174            method,
1175            args,
1176            row_index,
1177        )
1178    }
1179
1180    /// Evaluate a method on a value
1181    fn evaluate_method_on_value(
1182        &mut self,
1183        value: &DataValue,
1184        method: &str,
1185        args: &[SqlExpression],
1186        row_index: usize,
1187    ) -> Result<DataValue> {
1188        // First, try to proxy the method through the function registry
1189        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1190
1191        // Map method names to function names (case-insensitive matching)
1192        let function_name = match method.to_lowercase().as_str() {
1193            "trim" => "TRIM",
1194            "trimstart" | "trimbegin" => "TRIMSTART",
1195            "trimend" => "TRIMEND",
1196            "length" | "len" => "LENGTH",
1197            "contains" => "CONTAINS",
1198            "startswith" => "STARTSWITH",
1199            "endswith" => "ENDSWITH",
1200            "indexof" => "INDEXOF",
1201            _ => method, // Try the method name as-is
1202        };
1203
1204        // Check if we have this function in the registry
1205        if self.function_registry.get(function_name).is_some() {
1206            debug!(
1207                "Proxying method '{}' through function registry as '{}'",
1208                method, function_name
1209            );
1210
1211            // Prepare arguments: receiver is the first argument, followed by method args
1212            let mut func_args = vec![value.clone()];
1213
1214            // Evaluate method arguments and add them
1215            for arg in args {
1216                func_args.push(self.evaluate(arg, row_index)?);
1217            }
1218
1219            // Get the function and call it
1220            let func = self.function_registry.get(function_name).unwrap();
1221            return func.evaluate(&func_args);
1222        }
1223
1224        // If not in registry, the method is not supported
1225        // All methods should be registered in the function registry
1226        Err(anyhow!(
1227            "Method '{}' not found. It should be registered in the function registry.",
1228            method
1229        ))
1230    }
1231
1232    /// Evaluate a CASE expression
1233    fn evaluate_case_expression(
1234        &mut self,
1235        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1236        else_branch: &Option<Box<SqlExpression>>,
1237        row_index: usize,
1238    ) -> Result<DataValue> {
1239        debug!(
1240            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1241            row_index
1242        );
1243
1244        // Evaluate each WHEN condition in order
1245        for branch in when_branches {
1246            // Evaluate the condition as a boolean
1247            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1248
1249            if condition_result {
1250                debug!("CASE: WHEN condition matched, evaluating result expression");
1251                return self.evaluate(&branch.result, row_index);
1252            }
1253        }
1254
1255        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1256        if let Some(else_expr) = else_branch {
1257            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1258            self.evaluate(else_expr, row_index)
1259        } else {
1260            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1261            Ok(DataValue::Null)
1262        }
1263    }
1264
1265    /// Evaluate a simple CASE expression
1266    fn evaluate_simple_case_expression(
1267        &mut self,
1268        expr: &Box<SqlExpression>,
1269        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1270        else_branch: &Option<Box<SqlExpression>>,
1271        row_index: usize,
1272    ) -> Result<DataValue> {
1273        debug!(
1274            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1275            row_index
1276        );
1277
1278        // Evaluate the main expression once
1279        let case_value = self.evaluate(expr, row_index)?;
1280        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1281
1282        // Compare against each WHEN value in order
1283        for branch in when_branches {
1284            // Evaluate the WHEN value
1285            let when_value = self.evaluate(&branch.value, row_index)?;
1286
1287            // Check for equality
1288            if self.values_equal(&case_value, &when_value)? {
1289                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1290                return self.evaluate(&branch.result, row_index);
1291            }
1292        }
1293
1294        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1295        if let Some(else_expr) = else_branch {
1296            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1297            self.evaluate(else_expr, row_index)
1298        } else {
1299            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1300            Ok(DataValue::Null)
1301        }
1302    }
1303
1304    /// Check if two DataValues are equal
1305    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1306        match (left, right) {
1307            (DataValue::Null, DataValue::Null) => Ok(true),
1308            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1309            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1310            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1311            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1312            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1313            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1314            // Type coercion for numeric comparisons
1315            (DataValue::Integer(a), DataValue::Float(b)) => {
1316                Ok((*a as f64 - b).abs() < f64::EPSILON)
1317            }
1318            (DataValue::Float(a), DataValue::Integer(b)) => {
1319                Ok((a - *b as f64).abs() < f64::EPSILON)
1320            }
1321            _ => Ok(false),
1322        }
1323    }
1324
1325    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1326    fn evaluate_condition_as_bool(
1327        &mut self,
1328        expr: &SqlExpression,
1329        row_index: usize,
1330    ) -> Result<bool> {
1331        let value = self.evaluate(expr, row_index)?;
1332
1333        match value {
1334            DataValue::Boolean(b) => Ok(b),
1335            DataValue::Integer(i) => Ok(i != 0),
1336            DataValue::Float(f) => Ok(f != 0.0),
1337            DataValue::Null => Ok(false),
1338            DataValue::String(s) => Ok(!s.is_empty()),
1339            DataValue::InternedString(s) => Ok(!s.is_empty()),
1340            _ => Ok(true), // Other types are considered truthy
1341        }
1342    }
1343
1344    /// Evaluate a DATETIME constructor expression
1345    fn evaluate_datetime_constructor(
1346        &self,
1347        year: i32,
1348        month: u32,
1349        day: u32,
1350        hour: Option<u32>,
1351        minute: Option<u32>,
1352        second: Option<u32>,
1353    ) -> Result<DataValue> {
1354        use chrono::{NaiveDate, TimeZone, Utc};
1355
1356        // Create a NaiveDate
1357        let date = NaiveDate::from_ymd_opt(year, month, day)
1358            .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1359
1360        // Create datetime with provided time components or defaults
1361        let hour = hour.unwrap_or(0);
1362        let minute = minute.unwrap_or(0);
1363        let second = second.unwrap_or(0);
1364
1365        let naive_datetime = date
1366            .and_hms_opt(hour, minute, second)
1367            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1368
1369        // Convert to UTC DateTime
1370        let datetime = Utc.from_utc_datetime(&naive_datetime);
1371
1372        // Format as string with milliseconds
1373        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1374        Ok(DataValue::String(datetime_str))
1375    }
1376
1377    /// Evaluate a DATETIME.TODAY constructor expression
1378    fn evaluate_datetime_today(
1379        &self,
1380        hour: Option<u32>,
1381        minute: Option<u32>,
1382        second: Option<u32>,
1383    ) -> Result<DataValue> {
1384        use chrono::{TimeZone, Utc};
1385
1386        // Get today's date in UTC
1387        let today = Utc::now().date_naive();
1388
1389        // Create datetime with provided time components or defaults
1390        let hour = hour.unwrap_or(0);
1391        let minute = minute.unwrap_or(0);
1392        let second = second.unwrap_or(0);
1393
1394        let naive_datetime = today
1395            .and_hms_opt(hour, minute, second)
1396            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1397
1398        // Convert to UTC DateTime
1399        let datetime = Utc.from_utc_datetime(&naive_datetime);
1400
1401        // Format as string with milliseconds
1402        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1403        Ok(DataValue::String(datetime_str))
1404    }
1405}
1406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table: a = 10 (Integer), b = 2.5 (Float), c = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 1.25 is exactly representable and, unlike 3.14, does not trip
        // clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("1.25".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(1.25));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());
        // Mixed numeric types are coerced for comparison.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        // Unrelated types never compare equal.
        assert!(!evaluator
            .values_equal(&DataValue::String("1".to_string()), &DataValue::Integer(1))
            .unwrap());
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let zero = SqlExpression::NumberLiteral("0".to_string());
        let seven = SqlExpression::NumberLiteral("7".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());
        assert!(evaluator.evaluate_condition_as_bool(&seven, 0).unwrap());
    }

    #[test]
    fn test_datetime_constructor() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .evaluate_datetime_constructor(2024, 1, 2, Some(3), Some(4), Some(5))
            .unwrap();
        assert_eq!(
            result,
            DataValue::String("2024-01-02 03:04:05.000".to_string())
        );

        // Missing time components default to midnight.
        let result = evaluator
            .evaluate_datetime_constructor(2024, 1, 2, None, None, None)
            .unwrap();
        assert_eq!(
            result,
            DataValue::String("2024-01-02 00:00:00.000".to_string())
        );

        let result = evaluator.evaluate_datetime_constructor(2024, 2, 30, None, None, None);
        assert!(result.unwrap_err().to_string().contains("Invalid date"));
    }
}