Skip to main content

sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; // New registry
6use crate::sql::aggregates::AggregateRegistry; // Old registry (for migration)
7use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::Instant;
16use tracing::{debug, info};
17
18/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
19/// This is different from `RecursiveWhereEvaluator` which returns boolean
20pub struct ArithmeticEvaluator<'a> {
21    table: &'a DataTable,
22    _date_notation: String,
23    function_registry: Arc<FunctionRegistry>,
24    aggregate_registry: Arc<AggregateRegistry>, // Old registry (being phased out)
25    new_aggregate_registry: Arc<AggregateFunctionRegistry>, // New registry
26    window_function_registry: Arc<WindowFunctionRegistry>,
27    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
28    window_contexts: HashMap<u64, Arc<WindowContext>>, // Cache window contexts by hash
29    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
30}
31
32impl<'a> ArithmeticEvaluator<'a> {
33    #[must_use]
34    pub fn new(table: &'a DataTable) -> Self {
35        Self {
36            table,
37            _date_notation: get_date_notation(),
38            function_registry: Arc::new(FunctionRegistry::new()),
39            aggregate_registry: Arc::new(AggregateRegistry::new()),
40            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
41            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
42            visible_rows: None,
43            window_contexts: HashMap::new(),
44            table_aliases: HashMap::new(),
45        }
46    }
47
48    #[must_use]
49    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
50        Self {
51            table,
52            _date_notation: date_notation,
53            function_registry: Arc::new(FunctionRegistry::new()),
54            aggregate_registry: Arc::new(AggregateRegistry::new()),
55            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
56            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
57            visible_rows: None,
58            window_contexts: HashMap::new(),
59            table_aliases: HashMap::new(),
60        }
61    }
62
63    /// Set visible rows for aggregate functions (for filtered views)
64    #[must_use]
65    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
66        self.visible_rows = Some(rows);
67        self
68    }
69
70    /// Set table aliases for qualified column resolution
71    #[must_use]
72    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
73        self.table_aliases = aliases;
74        self
75    }
76
77    #[must_use]
78    pub fn with_date_notation_and_registry(
79        table: &'a DataTable,
80        date_notation: String,
81        function_registry: Arc<FunctionRegistry>,
82    ) -> Self {
83        Self {
84            table,
85            _date_notation: date_notation,
86            function_registry,
87            aggregate_registry: Arc::new(AggregateRegistry::new()),
88            new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
89            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
90            visible_rows: None,
91            window_contexts: HashMap::new(),
92            table_aliases: HashMap::new(),
93        }
94    }
95
96    /// Find a column name similar to the given name using edit distance
97    fn find_similar_column(&self, name: &str) -> Option<String> {
98        let columns = self.table.column_names();
99        let mut best_match: Option<(String, usize)> = None;
100
101        for col in columns {
102            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
103            // Only suggest if distance is small (likely a typo)
104            // Allow up to 3 edits for longer names
105            let max_distance = if name.len() > 10 { 3 } else { 2 };
106            if distance <= max_distance {
107                match &best_match {
108                    None => best_match = Some((col, distance)),
109                    Some((_, best_dist)) if distance < *best_dist => {
110                        best_match = Some((col, distance));
111                    }
112                    _ => {}
113                }
114            }
115        }
116
117        best_match.map(|(name, _)| name)
118    }
119
120    /// Calculate Levenshtein edit distance between two strings
121    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
122        // Use the shared implementation from string_methods
123        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
124    }
125
126    /// Evaluate an SQL expression to produce a `DataValue`
127    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
128        debug!(
129            "ArithmeticEvaluator: evaluating {:?} for row {}",
130            expr, row_index
131        );
132
133        match expr {
134            SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
135            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
136            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
137            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
138            SqlExpression::Null => Ok(DataValue::Null),
139            SqlExpression::BinaryOp { left, op, right } => {
140                self.evaluate_binary_op(left, op, right, row_index)
141            }
142            SqlExpression::FunctionCall {
143                name,
144                args,
145                distinct,
146            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
147            SqlExpression::WindowFunction {
148                name,
149                args,
150                window_spec,
151            } => self.evaluate_window_function(name, args, window_spec, row_index),
152            SqlExpression::MethodCall {
153                object,
154                method,
155                args,
156            } => self.evaluate_method_call(object, method, args, row_index),
157            SqlExpression::ChainedMethodCall { base, method, args } => {
158                // Evaluate the base expression first, then apply the method
159                let base_value = self.evaluate(base, row_index)?;
160                self.evaluate_method_on_value(&base_value, method, args, row_index)
161            }
162            SqlExpression::Between { expr, lower, upper } => {
163                let val = self.evaluate(expr, row_index)?;
164                let lo = self.evaluate(lower, row_index)?;
165                let hi = self.evaluate(upper, row_index)?;
166                let ge = compare_with_op(&val, &lo, ">=", false);
167                let le = compare_with_op(&val, &hi, "<=", false);
168                Ok(DataValue::Boolean(ge && le))
169            }
170            SqlExpression::CaseExpression {
171                when_branches,
172                else_branch,
173            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
174            SqlExpression::SimpleCaseExpression {
175                expr,
176                when_branches,
177                else_branch,
178            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
179            SqlExpression::DateTimeConstructor {
180                year,
181                month,
182                day,
183                hour,
184                minute,
185                second,
186            } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
187            SqlExpression::DateTimeToday {
188                hour,
189                minute,
190                second,
191            } => self.evaluate_datetime_today(*hour, *minute, *second),
192            _ => Err(anyhow!(
193                "Unsupported expression type for arithmetic evaluation: {:?}",
194                expr
195            )),
196        }
197    }
198
199    /// Evaluate a column reference with proper table scoping
200    fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
201        if let Some(table_prefix) = &column_ref.table_prefix {
202            // Resolve alias if it exists in table_aliases map
203            let actual_table = self
204                .table_aliases
205                .get(table_prefix)
206                .map(|s| s.as_str())
207                .unwrap_or(table_prefix);
208
209            // Try qualified lookup with resolved table name
210            let qualified_name = format!("{}.{}", actual_table, column_ref.name);
211
212            if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
213                debug!(
214                    "Resolved {}.{} -> '{}' at index {}",
215                    table_prefix, column_ref.name, qualified_name, col_idx
216                );
217                return self
218                    .table
219                    .get_value(row_index, col_idx)
220                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
221                    .map(|v| v.clone());
222            }
223
224            // Fallback: try unqualified lookup
225            if let Some(col_idx) = self.table.get_column_index(&column_ref.name) {
226                debug!(
227                    "Resolved {}.{} -> unqualified '{}' at index {}",
228                    table_prefix, column_ref.name, column_ref.name, col_idx
229                );
230                return self
231                    .table
232                    .get_value(row_index, col_idx)
233                    .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
234                    .map(|v| v.clone());
235            }
236
237            // If not found, return error
238            Err(anyhow!(
239                "Column '{}' not found. Table '{}' may not support qualified column names",
240                qualified_name,
241                actual_table
242            ))
243        } else {
244            // Simple column name lookup
245            self.evaluate_column(&column_ref.name, row_index)
246        }
247    }
248
249    /// Evaluate a column reference
250    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
251        // First try to resolve qualified column names (table.column or alias.column)
252        let resolved_column = if column_name.contains('.') {
253            // Split on last dot to handle cases like "schema.table.column"
254            if let Some(dot_pos) = column_name.rfind('.') {
255                let _table_or_alias = &column_name[..dot_pos];
256                let col_name = &column_name[dot_pos + 1..];
257
258                // For now, just use the column name part
259                // In the future, we could validate the table/alias part
260                debug!(
261                    "Resolving qualified column: {} -> {}",
262                    column_name, col_name
263                );
264                col_name.to_string()
265            } else {
266                column_name.to_string()
267            }
268        } else {
269            column_name.to_string()
270        };
271
272        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
273            idx
274        } else if resolved_column != column_name {
275            // If not found, try the original name
276            if let Some(idx) = self.table.get_column_index(column_name) {
277                idx
278            } else {
279                let suggestion = self.find_similar_column(&resolved_column);
280                return Err(match suggestion {
281                    Some(similar) => anyhow!(
282                        "Column '{}' not found. Did you mean '{}'?",
283                        column_name,
284                        similar
285                    ),
286                    None => anyhow!("Column '{}' not found", column_name),
287                });
288            }
289        } else {
290            let suggestion = self.find_similar_column(&resolved_column);
291            return Err(match suggestion {
292                Some(similar) => anyhow!(
293                    "Column '{}' not found. Did you mean '{}'?",
294                    column_name,
295                    similar
296                ),
297                None => anyhow!("Column '{}' not found", column_name),
298            });
299        };
300
301        if row_index >= self.table.row_count() {
302            return Err(anyhow!("Row index {} out of bounds", row_index));
303        }
304
305        let row = self
306            .table
307            .get_row(row_index)
308            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
309
310        let value = row
311            .get(col_index)
312            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
313
314        Ok(value.clone())
315    }
316
317    /// Evaluate a number literal (handles both integers and floats)
318    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
319        // Try to parse as integer first
320        if let Ok(int_val) = number_str.parse::<i64>() {
321            return Ok(DataValue::Integer(int_val));
322        }
323
324        // If that fails, try as float
325        if let Ok(float_val) = number_str.parse::<f64>() {
326            return Ok(DataValue::Float(float_val));
327        }
328
329        Err(anyhow!("Invalid number literal: {}", number_str))
330    }
331
332    /// Evaluate a binary operation (arithmetic)
333    fn evaluate_binary_op(
334        &mut self,
335        left: &SqlExpression,
336        op: &str,
337        right: &SqlExpression,
338        row_index: usize,
339    ) -> Result<DataValue> {
340        let left_val = self.evaluate(left, row_index)?;
341        let right_val = self.evaluate(right, row_index)?;
342
343        debug!(
344            "ArithmeticEvaluator: {} {} {}",
345            self.format_value(&left_val),
346            op,
347            self.format_value(&right_val)
348        );
349
350        match op {
351            "+" => self.add_values(&left_val, &right_val),
352            "-" => self.subtract_values(&left_val, &right_val),
353            "*" => self.multiply_values(&left_val, &right_val),
354            "/" => self.divide_values(&left_val, &right_val),
355            "%" => {
356                // Modulo operator - call MOD function
357                let args = vec![left.clone(), right.clone()];
358                self.evaluate_function("MOD", &args, row_index)
359            }
360            // Comparison operators (return boolean results)
361            // Use centralized comparison logic for consistency
362            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
363                let result = compare_with_op(&left_val, &right_val, op, false);
364                Ok(DataValue::Boolean(result))
365            }
366            // IS NULL / IS NOT NULL operators
367            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
368            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
369            // Logical operators
370            "AND" => {
371                let left_bool = self.to_bool(&left_val)?;
372                let right_bool = self.to_bool(&right_val)?;
373                Ok(DataValue::Boolean(left_bool && right_bool))
374            }
375            "OR" => {
376                let left_bool = self.to_bool(&left_val)?;
377                let right_bool = self.to_bool(&right_val)?;
378                Ok(DataValue::Boolean(left_bool || right_bool))
379            }
380            // LIKE operator - SQL pattern matching
381            "LIKE" => {
382                let text = self.value_to_string(&left_val);
383                let pattern = self.value_to_string(&right_val);
384                let matches = self.sql_like_match(&text, &pattern);
385                Ok(DataValue::Boolean(matches))
386            }
387            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
388        }
389    }
390
391    /// Add two `DataValues` with type coercion
392    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
393        // NULL handling - any operation with NULL returns NULL
394        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
395            return Ok(DataValue::Null);
396        }
397
398        match (left, right) {
399            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
400            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
401            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
402            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
403            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
404        }
405    }
406
407    /// Subtract two `DataValues` with type coercion
408    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
409        // NULL handling - any operation with NULL returns NULL
410        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
411            return Ok(DataValue::Null);
412        }
413
414        match (left, right) {
415            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
416            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
417            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
418            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
419            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
420        }
421    }
422
423    /// Multiply two `DataValues` with type coercion
424    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
425        // NULL handling - any operation with NULL returns NULL
426        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
427            return Ok(DataValue::Null);
428        }
429
430        match (left, right) {
431            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
432            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
433            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
434            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
435            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
436        }
437    }
438
439    /// Divide two `DataValues` with type coercion
440    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
441        // NULL handling - any operation with NULL returns NULL
442        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
443            return Ok(DataValue::Null);
444        }
445
446        // Check for division by zero first
447        let is_zero = match right {
448            DataValue::Integer(0) => true,
449            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
450            _ => false,
451        };
452
453        if is_zero {
454            return Err(anyhow!("Division by zero"));
455        }
456
457        match (left, right) {
458            (DataValue::Integer(a), DataValue::Integer(b)) => {
459                // Integer division - if result is exact, keep as int, otherwise promote to float
460                if a % b == 0 {
461                    Ok(DataValue::Integer(a / b))
462                } else {
463                    Ok(DataValue::Float(*a as f64 / *b as f64))
464                }
465            }
466            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
467            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
468            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
469            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
470        }
471    }
472
473    /// Format a `DataValue` for debug output
474    fn format_value(&self, value: &DataValue) -> String {
475        match value {
476            DataValue::Integer(i) => i.to_string(),
477            DataValue::Float(f) => f.to_string(),
478            DataValue::String(s) => format!("'{s}'"),
479            _ => format!("{value:?}"),
480        }
481    }
482
483    /// Convert a DataValue to boolean for logical operations
484    fn to_bool(&self, value: &DataValue) -> Result<bool> {
485        match value {
486            DataValue::Boolean(b) => Ok(*b),
487            DataValue::Integer(i) => Ok(*i != 0),
488            DataValue::Float(f) => Ok(*f != 0.0),
489            DataValue::Null => Ok(false),
490            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
491        }
492    }
493
494    /// Convert DataValue to string for pattern matching
495    fn value_to_string(&self, value: &DataValue) -> String {
496        match value {
497            DataValue::String(s) => s.clone(),
498            DataValue::InternedString(s) => s.to_string(),
499            DataValue::Integer(i) => i.to_string(),
500            DataValue::Float(f) => f.to_string(),
501            DataValue::Boolean(b) => b.to_string(),
502            DataValue::DateTime(dt) => dt.to_string(),
503            DataValue::Vector(v) => {
504                // Format as "[x,y,z]"
505                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
506                format!("[{}]", components.join(","))
507            }
508            DataValue::Null => String::new(),
509        }
510    }
511
512    /// SQL LIKE pattern matching
513    /// Supports % (any chars) and _ (single char)
514    fn sql_like_match(&self, text: &str, pattern: &str) -> bool {
515        let pattern_chars: Vec<char> = pattern.chars().collect();
516        let text_chars: Vec<char> = text.chars().collect();
517
518        self.like_match_recursive(&text_chars, 0, &pattern_chars, 0)
519    }
520
521    /// Recursive helper for LIKE matching
522    fn like_match_recursive(
523        &self,
524        text: &[char],
525        text_pos: usize,
526        pattern: &[char],
527        pattern_pos: usize,
528    ) -> bool {
529        // If we've consumed both text and pattern, it's a match
530        if pattern_pos >= pattern.len() {
531            return text_pos >= text.len();
532        }
533
534        // Handle % wildcard (matches zero or more characters)
535        if pattern[pattern_pos] == '%' {
536            // Try matching zero characters (skip the %)
537            if self.like_match_recursive(text, text_pos, pattern, pattern_pos + 1) {
538                return true;
539            }
540            // Try matching one or more characters
541            if text_pos < text.len() {
542                return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos);
543            }
544            return false;
545        }
546
547        // If text is consumed but pattern isn't, no match
548        if text_pos >= text.len() {
549            return false;
550        }
551
552        // Handle _ wildcard (matches exactly one character)
553        if pattern[pattern_pos] == '_' {
554            return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
555        }
556
557        // Handle literal character match
558        if text[text_pos] == pattern[pattern_pos] {
559            return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
560        }
561
562        false
563    }
564
565    /// Evaluate a function call
566    fn evaluate_function_with_distinct(
567        &mut self,
568        name: &str,
569        args: &[SqlExpression],
570        distinct: bool,
571        row_index: usize,
572    ) -> Result<DataValue> {
573        // If DISTINCT is specified, handle it specially for aggregate functions
574        if distinct {
575            let name_upper = name.to_uppercase();
576
577            // Check if it's an aggregate function in either registry
578            if self.aggregate_registry.is_aggregate(&name_upper)
579                || self.new_aggregate_registry.contains(&name_upper)
580            {
581                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
582            } else {
583                return Err(anyhow!(
584                    "DISTINCT can only be used with aggregate functions"
585                ));
586            }
587        }
588
589        // Otherwise, use the regular evaluation
590        self.evaluate_function(name, args, row_index)
591    }
592
593    fn evaluate_aggregate_with_distinct(
594        &mut self,
595        name: &str,
596        args: &[SqlExpression],
597        _row_index: usize,
598    ) -> Result<DataValue> {
599        let name_upper = name.to_uppercase();
600
601        // Check new aggregate registry first for migrated functions
602        if self.new_aggregate_registry.get(&name_upper).is_some() {
603            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
604                visible.clone()
605            } else {
606                (0..self.table.rows.len()).collect()
607            };
608
609            // Collect and deduplicate values for DISTINCT
610            let mut vals = Vec::new();
611            for &row_idx in &rows_to_process {
612                if !args.is_empty() {
613                    let value = self.evaluate(&args[0], row_idx)?;
614                    vals.push(value);
615                }
616            }
617
618            // Deduplicate values
619            let mut seen = HashSet::new();
620            let unique_values: Vec<_> = vals
621                .into_iter()
622                .filter(|v| {
623                    let key = format!("{:?}", v);
624                    seen.insert(key)
625                })
626                .collect();
627
628            // Get the aggregate function from the new registry
629            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
630            let mut state = agg_func.create_state();
631
632            // Use unique values
633            for value in &unique_values {
634                state.accumulate(value)?;
635            }
636
637            return Ok(state.finalize());
638        }
639
640        // Check old aggregate registry (DISTINCT handling)
641        if self.aggregate_registry.get(&name_upper).is_some() {
642            // Determine which rows to process first
643            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
644                visible.clone()
645            } else {
646                (0..self.table.rows.len()).collect()
647            };
648
649            // Special handling for STRING_AGG with separator parameter
650            if name_upper == "STRING_AGG" && args.len() >= 2 {
651                // STRING_AGG(DISTINCT column, separator)
652                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
653                    // Evaluate the separator (second argument) once
654                    if args.len() >= 2 {
655                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
656                        match separator {
657                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
658                            DataValue::InternedString(s) => {
659                                crate::sql::aggregates::StringAggState::new(&s)
660                            }
661                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
662                        }
663                    } else {
664                        crate::sql::aggregates::StringAggState::new(",")
665                    },
666                );
667
668                // Evaluate the first argument (column) for each row and accumulate
669                // Handle DISTINCT - use a HashSet to track seen values
670                let mut seen_values = HashSet::new();
671
672                for &row_idx in &rows_to_process {
673                    let value = self.evaluate(&args[0], row_idx)?;
674
675                    // Skip if we've seen this value
676                    if !seen_values.insert(value.clone()) {
677                        continue; // Skip duplicate values
678                    }
679
680                    // Now get the aggregate function and accumulate
681                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
682                    agg_func.accumulate(&mut state, &value)?;
683                }
684
685                // Finalize the aggregate
686                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
687                return Ok(agg_func.finalize(state));
688            }
689
690            // For other aggregates with DISTINCT
691            // Evaluate the argument expression for each row
692            let mut vals = Vec::new();
693            for &row_idx in &rows_to_process {
694                if !args.is_empty() {
695                    let value = self.evaluate(&args[0], row_idx)?;
696                    vals.push(value);
697                }
698            }
699
700            // Deduplicate values for DISTINCT
701            let mut seen = HashSet::new();
702            let mut unique_values = Vec::new();
703            for value in vals {
704                if seen.insert(value.clone()) {
705                    unique_values.push(value);
706                }
707            }
708
709            // Now get the aggregate function and process
710            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
711            let mut state = agg_func.init();
712
713            // Use unique values
714            for value in &unique_values {
715                agg_func.accumulate(&mut state, value)?;
716            }
717
718            return Ok(agg_func.finalize(state));
719        }
720
721        Err(anyhow!("Unknown aggregate function: {}", name))
722    }
723
724    fn evaluate_function(
725        &mut self,
726        name: &str,
727        args: &[SqlExpression],
728        row_index: usize,
729    ) -> Result<DataValue> {
730        // Check if this is an aggregate function
731        let name_upper = name.to_uppercase();
732
733        // Check new aggregate registry first (for migrated functions)
734        if self.new_aggregate_registry.get(&name_upper).is_some() {
735            // Use new registry for SUM
736            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
737                visible.clone()
738            } else {
739                (0..self.table.rows.len()).collect()
740            };
741
742            // Get the aggregate function from the new registry
743            let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
744            let mut state = agg_func.create_state();
745
746            // Special handling for COUNT(*)
747            if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
748                if args.is_empty()
749                    || (args.len() == 1
750                        && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
751                    || (args.len() == 1
752                        && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
753                {
754                    // COUNT(*) or COUNT_STAR - count all rows
755                    for _ in &rows_to_process {
756                        state.accumulate(&DataValue::Integer(1))?;
757                    }
758                } else {
759                    // COUNT(column) - count non-null values
760                    for &row_idx in &rows_to_process {
761                        let value = self.evaluate(&args[0], row_idx)?;
762                        state.accumulate(&value)?;
763                    }
764                }
765            } else {
766                // Other aggregates - evaluate arguments and accumulate
767                if !args.is_empty() {
768                    for &row_idx in &rows_to_process {
769                        let value = self.evaluate(&args[0], row_idx)?;
770                        state.accumulate(&value)?;
771                    }
772                }
773            }
774
775            return Ok(state.finalize());
776        }
777
778        // Check old aggregate registry (for non-migrated functions)
779        if self.aggregate_registry.get(&name_upper).is_some() {
780            // Determine which rows to process first
781            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
782                visible.clone()
783            } else {
784                (0..self.table.rows.len()).collect()
785            };
786
787            // Special handling for STRING_AGG with separator parameter
788            if name_upper == "STRING_AGG" && args.len() >= 2 {
789                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
790                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
791                    // Evaluate the separator (second argument) once
792                    if args.len() >= 2 {
793                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
794                        match separator {
795                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
796                            DataValue::InternedString(s) => {
797                                crate::sql::aggregates::StringAggState::new(&s)
798                            }
799                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
800                        }
801                    } else {
802                        crate::sql::aggregates::StringAggState::new(",")
803                    },
804                );
805
806                // Evaluate the first argument (column) for each row and accumulate
807                for &row_idx in &rows_to_process {
808                    let value = self.evaluate(&args[0], row_idx)?;
809                    // Now get the aggregate function and accumulate
810                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
811                    agg_func.accumulate(&mut state, &value)?;
812                }
813
814                // Finalize the aggregate
815                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
816                return Ok(agg_func.finalize(state));
817            }
818
819            // Evaluate arguments first if needed (to avoid borrow issues)
820            let values = if !args.is_empty()
821                && !(args.len() == 1
822                    && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
823            {
824                // Evaluate the argument expression for each row
825                let mut vals = Vec::new();
826                for &row_idx in &rows_to_process {
827                    let value = self.evaluate(&args[0], row_idx)?;
828                    vals.push(value);
829                }
830                Some(vals)
831            } else {
832                None
833            };
834
835            // Now get the aggregate function and process
836            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
837            let mut state = agg_func.init();
838
839            if let Some(values) = values {
840                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
841                for value in &values {
842                    agg_func.accumulate(&mut state, value)?;
843                }
844            } else {
845                // COUNT(*) case
846                for _ in &rows_to_process {
847                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
848                }
849            }
850
851            return Ok(agg_func.finalize(state));
852        }
853
854        // First check if this function exists in the registry
855        if self.function_registry.get(name).is_some() {
856            // Evaluate all arguments first to avoid borrow issues
857            let mut evaluated_args = Vec::new();
858            for arg in args {
859                evaluated_args.push(self.evaluate(arg, row_index)?);
860            }
861
862            // Get the function and call it
863            let func = self.function_registry.get(name).unwrap();
864            return func.evaluate(&evaluated_args);
865        }
866
867        // If not in registry, return error for unknown function
868        Err(anyhow!("Unknown function: {}", name))
869    }
870
871    /// Get or create a WindowContext for the given specification
872    /// Public to allow pre-creation of contexts in query engine (optimization)
873    pub fn get_or_create_window_context(
874        &mut self,
875        spec: &WindowSpec,
876    ) -> Result<Arc<WindowContext>> {
877        let overall_start = Instant::now();
878
879        // Create a hash-based key for fast caching (much faster than format!("{:?}", spec))
880        let key = spec.compute_hash();
881
882        if let Some(context) = self.window_contexts.get(&key) {
883            info!(
884                "WindowContext cache hit for spec (lookup: {:.2}μs)",
885                overall_start.elapsed().as_micros()
886            );
887            return Ok(Arc::clone(context));
888        }
889
890        info!("WindowContext cache miss - creating new context");
891        let dataview_start = Instant::now();
892
893        // Create a DataView from the table (with visible rows if filtered)
894        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
895            // Create a filtered view
896            let view = DataView::new(Arc::new(self.table.clone()));
897            // Apply filtering based on visible rows
898            // Note: This is a simplified approach - in production we'd need proper filtering
899            view
900        } else {
901            DataView::new(Arc::new(self.table.clone()))
902        };
903
904        info!(
905            "DataView creation took {:.2}μs",
906            dataview_start.elapsed().as_micros()
907        );
908        let context_start = Instant::now();
909
910        // Create the WindowContext with the full spec (including frame)
911        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
912
913        info!(
914            "WindowContext::new_with_spec took {:.2}ms (rows: {})",
915            context_start.elapsed().as_secs_f64() * 1000.0,
916            self.table.row_count()
917        );
918
919        let context = Arc::new(context);
920        self.window_contexts.insert(key, Arc::clone(&context));
921
922        info!(
923            "Total WindowContext creation (cache miss) took {:.2}ms",
924            overall_start.elapsed().as_secs_f64() * 1000.0
925        );
926
927        Ok(context)
928    }
929
930    /// Evaluate a window function
931    fn evaluate_window_function(
932        &mut self,
933        name: &str,
934        args: &[SqlExpression],
935        spec: &WindowSpec,
936        row_index: usize,
937    ) -> Result<DataValue> {
938        let func_start = Instant::now();
939        let name_upper = name.to_uppercase();
940
941        // First check if this is a syntactic sugar function in the registry
942        debug!("Looking for window function {} in registry", name_upper);
943        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
944            debug!("Found window function {} in registry", name_upper);
945
946            // Dereference to get the actual window function
947            let window_fn = window_fn_arc.as_ref();
948
949            // Validate arguments
950            window_fn.validate_args(args)?;
951
952            // Transform the window spec based on the function's requirements
953            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
954
955            // Get or create the window context with the transformed spec
956            let context = self.get_or_create_window_context(&transformed_spec)?;
957
958            // Create an expression evaluator adapter
959            struct EvaluatorAdapter<'a, 'b> {
960                evaluator: &'a mut ArithmeticEvaluator<'b>,
961                row_index: usize,
962            }
963
964            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
965                fn evaluate(
966                    &mut self,
967                    expr: &SqlExpression,
968                    row_index: usize,
969                ) -> Result<DataValue> {
970                    self.evaluator.evaluate(expr, row_index)
971                }
972            }
973
974            let mut adapter = EvaluatorAdapter {
975                evaluator: self,
976                row_index,
977            };
978
979            let compute_start = Instant::now();
980            // Call the window function's compute method
981            let result = window_fn.compute(&context, row_index, args, &mut adapter);
982
983            info!(
984                "{} (registry) evaluation: total={:.2}μs, compute={:.2}μs",
985                name_upper,
986                func_start.elapsed().as_micros(),
987                compute_start.elapsed().as_micros()
988            );
989
990            return result;
991        }
992
993        // Fall back to built-in window functions
994        let context_start = Instant::now();
995        let context = self.get_or_create_window_context(spec)?;
996        let context_time = context_start.elapsed();
997
998        let eval_start = Instant::now();
999
1000        let result = match name_upper.as_str() {
1001            "LAG" => {
1002                // LAG(column, offset, default)
1003                if args.is_empty() {
1004                    return Err(anyhow!("LAG requires at least 1 argument"));
1005                }
1006
1007                // Get column name
1008                let column = match &args[0] {
1009                    SqlExpression::Column(col) => col.clone(),
1010                    _ => return Err(anyhow!("LAG first argument must be a column")),
1011                };
1012
1013                // Get offset (default 1)
1014                let offset = if args.len() > 1 {
1015                    match self.evaluate(&args[1], row_index)? {
1016                        DataValue::Integer(i) => i as i32,
1017                        _ => return Err(anyhow!("LAG offset must be an integer")),
1018                    }
1019                } else {
1020                    1
1021                };
1022
1023                let offset_start = Instant::now();
1024                // Get value at offset
1025                let value = context
1026                    .get_offset_value(row_index, -offset, &column.name)
1027                    .unwrap_or(DataValue::Null);
1028
1029                debug!(
1030                    "LAG offset access took {:.2}μs (offset={})",
1031                    offset_start.elapsed().as_micros(),
1032                    offset
1033                );
1034
1035                Ok(value)
1036            }
1037            "LEAD" => {
1038                // LEAD(column, offset, default)
1039                if args.is_empty() {
1040                    return Err(anyhow!("LEAD requires at least 1 argument"));
1041                }
1042
1043                // Get column name
1044                let column = match &args[0] {
1045                    SqlExpression::Column(col) => col.clone(),
1046                    _ => return Err(anyhow!("LEAD first argument must be a column")),
1047                };
1048
1049                // Get offset (default 1)
1050                let offset = if args.len() > 1 {
1051                    match self.evaluate(&args[1], row_index)? {
1052                        DataValue::Integer(i) => i as i32,
1053                        _ => return Err(anyhow!("LEAD offset must be an integer")),
1054                    }
1055                } else {
1056                    1
1057                };
1058
1059                let offset_start = Instant::now();
1060                // Get value at offset
1061                let value = context
1062                    .get_offset_value(row_index, offset, &column.name)
1063                    .unwrap_or(DataValue::Null);
1064
1065                debug!(
1066                    "LEAD offset access took {:.2}μs (offset={})",
1067                    offset_start.elapsed().as_micros(),
1068                    offset
1069                );
1070
1071                Ok(value)
1072            }
1073            "ROW_NUMBER" => {
1074                // ROW_NUMBER() - no arguments
1075                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
1076            }
1077            "RANK" => {
1078                // RANK() - no arguments
1079                Ok(DataValue::Integer(context.get_rank(row_index)))
1080            }
1081            "DENSE_RANK" => {
1082                // DENSE_RANK() - no arguments
1083                Ok(DataValue::Integer(context.get_dense_rank(row_index)))
1084            }
1085            "FIRST_VALUE" => {
1086                // FIRST_VALUE(column) OVER (... ROWS ...)
1087                if args.is_empty() {
1088                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
1089                }
1090
1091                let column = match &args[0] {
1092                    SqlExpression::Column(col) => col.clone(),
1093                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
1094                };
1095
1096                // Use frame-aware version if frame is specified
1097                if context.has_frame() {
1098                    Ok(context
1099                        .get_frame_first_value(row_index, &column.name)
1100                        .unwrap_or(DataValue::Null))
1101                } else {
1102                    Ok(context
1103                        .get_first_value(row_index, &column.name)
1104                        .unwrap_or(DataValue::Null))
1105                }
1106            }
1107            "LAST_VALUE" => {
1108                // LAST_VALUE(column) OVER (... ROWS ...)
1109                if args.is_empty() {
1110                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
1111                }
1112
1113                let column = match &args[0] {
1114                    SqlExpression::Column(col) => col.clone(),
1115                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
1116                };
1117
1118                // Use frame-aware version if frame is specified
1119                if context.has_frame() {
1120                    Ok(context
1121                        .get_frame_last_value(row_index, &column.name)
1122                        .unwrap_or(DataValue::Null))
1123                } else {
1124                    Ok(context
1125                        .get_last_value(row_index, &column.name)
1126                        .unwrap_or(DataValue::Null))
1127                }
1128            }
1129            "SUM" => {
1130                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1131                if args.is_empty() {
1132                    return Err(anyhow!("SUM requires 1 argument"));
1133                }
1134
1135                let column = match &args[0] {
1136                    SqlExpression::Column(col) => col.clone(),
1137                    _ => return Err(anyhow!("SUM argument must be a column")),
1138                };
1139
1140                // Use frame-aware sum if frame is specified, otherwise use partition sum
1141                if context.has_frame() {
1142                    Ok(context
1143                        .get_frame_sum(row_index, &column.name)
1144                        .unwrap_or(DataValue::Null))
1145                } else {
1146                    Ok(context
1147                        .get_partition_sum(row_index, &column.name)
1148                        .unwrap_or(DataValue::Null))
1149                }
1150            }
1151            "AVG" => {
1152                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1153                if args.is_empty() {
1154                    return Err(anyhow!("AVG requires 1 argument"));
1155                }
1156
1157                let column = match &args[0] {
1158                    SqlExpression::Column(col) => col.clone(),
1159                    _ => return Err(anyhow!("AVG argument must be a column")),
1160                };
1161
1162                // Use frame-aware avg if frame is specified, otherwise use partition avg
1163                if context.has_frame() {
1164                    Ok(context
1165                        .get_frame_avg(row_index, &column.name)
1166                        .unwrap_or(DataValue::Null))
1167                } else {
1168                    Ok(context
1169                        .get_partition_avg(row_index, &column.name)
1170                        .unwrap_or(DataValue::Null))
1171                }
1172            }
1173            "STDDEV" | "STDEV" => {
1174                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1175                if args.is_empty() {
1176                    return Err(anyhow!("STDDEV requires 1 argument"));
1177                }
1178
1179                let column = match &args[0] {
1180                    SqlExpression::Column(col) => col.clone(),
1181                    _ => return Err(anyhow!("STDDEV argument must be a column")),
1182                };
1183
1184                Ok(context
1185                    .get_frame_stddev(row_index, &column.name)
1186                    .unwrap_or(DataValue::Null))
1187            }
1188            "VARIANCE" | "VAR" => {
1189                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1190                if args.is_empty() {
1191                    return Err(anyhow!("VARIANCE requires 1 argument"));
1192                }
1193
1194                let column = match &args[0] {
1195                    SqlExpression::Column(col) => col.clone(),
1196                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
1197                };
1198
1199                Ok(context
1200                    .get_frame_variance(row_index, &column.name)
1201                    .unwrap_or(DataValue::Null))
1202            }
1203            "MIN" => {
1204                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1205                if args.is_empty() {
1206                    return Err(anyhow!("MIN requires 1 argument"));
1207                }
1208
1209                let column = match &args[0] {
1210                    SqlExpression::Column(col) => col.clone(),
1211                    _ => return Err(anyhow!("MIN argument must be a column")),
1212                };
1213
1214                let frame_rows = context.get_frame_rows(row_index);
1215                if frame_rows.is_empty() {
1216                    return Ok(DataValue::Null);
1217                }
1218
1219                let source_table = context.source();
1220                let col_idx = source_table
1221                    .get_column_index(&column.name)
1222                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1223
1224                let mut min_value: Option<DataValue> = None;
1225                for &row_idx in &frame_rows {
1226                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1227                        if !matches!(value, DataValue::Null) {
1228                            match &min_value {
1229                                None => min_value = Some(value.clone()),
1230                                Some(current_min) => {
1231                                    if value < current_min {
1232                                        min_value = Some(value.clone());
1233                                    }
1234                                }
1235                            }
1236                        }
1237                    }
1238                }
1239
1240                Ok(min_value.unwrap_or(DataValue::Null))
1241            }
1242            "MAX" => {
1243                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1244                if args.is_empty() {
1245                    return Err(anyhow!("MAX requires 1 argument"));
1246                }
1247
1248                let column = match &args[0] {
1249                    SqlExpression::Column(col) => col.clone(),
1250                    _ => return Err(anyhow!("MAX argument must be a column")),
1251                };
1252
1253                let frame_rows = context.get_frame_rows(row_index);
1254                if frame_rows.is_empty() {
1255                    return Ok(DataValue::Null);
1256                }
1257
1258                let source_table = context.source();
1259                let col_idx = source_table
1260                    .get_column_index(&column.name)
1261                    .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1262
1263                let mut max_value: Option<DataValue> = None;
1264                for &row_idx in &frame_rows {
1265                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
1266                        if !matches!(value, DataValue::Null) {
1267                            match &max_value {
1268                                None => max_value = Some(value.clone()),
1269                                Some(current_max) => {
1270                                    if value > current_max {
1271                                        max_value = Some(value.clone());
1272                                    }
1273                                }
1274                            }
1275                        }
1276                    }
1277                }
1278
1279                Ok(max_value.unwrap_or(DataValue::Null))
1280            }
1281            "COUNT" => {
1282                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
1283                // Use frame-aware count if frame is specified, otherwise use partition count
1284
1285                if args.is_empty() {
1286                    // COUNT(*) OVER (...)
1287                    if context.has_frame() {
1288                        Ok(context
1289                            .get_frame_count(row_index, None)
1290                            .unwrap_or(DataValue::Null))
1291                    } else {
1292                        Ok(context
1293                            .get_partition_count(row_index, None)
1294                            .unwrap_or(DataValue::Null))
1295                    }
1296                } else {
1297                    // Check for COUNT(*)
1298                    let column = match &args[0] {
1299                        SqlExpression::Column(col) => {
1300                            if col.name == "*" {
1301                                // COUNT(*) - count all rows
1302                                if context.has_frame() {
1303                                    return Ok(context
1304                                        .get_frame_count(row_index, None)
1305                                        .unwrap_or(DataValue::Null));
1306                                } else {
1307                                    return Ok(context
1308                                        .get_partition_count(row_index, None)
1309                                        .unwrap_or(DataValue::Null));
1310                                }
1311                            }
1312                            col.clone()
1313                        }
1314                        SqlExpression::StringLiteral(s) if s == "*" => {
1315                            // COUNT(*) as StringLiteral
1316                            if context.has_frame() {
1317                                return Ok(context
1318                                    .get_frame_count(row_index, None)
1319                                    .unwrap_or(DataValue::Null));
1320                            } else {
1321                                return Ok(context
1322                                    .get_partition_count(row_index, None)
1323                                    .unwrap_or(DataValue::Null));
1324                            }
1325                        }
1326                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1327                    };
1328
1329                    // COUNT(column) - count non-null values
1330                    if context.has_frame() {
1331                        Ok(context
1332                            .get_frame_count(row_index, Some(&column.name))
1333                            .unwrap_or(DataValue::Null))
1334                    } else {
1335                        Ok(context
1336                            .get_partition_count(row_index, Some(&column.name))
1337                            .unwrap_or(DataValue::Null))
1338                    }
1339                }
1340            }
1341            _ => Err(anyhow!("Unknown window function: {}", name)),
1342        };
1343
1344        let eval_time = eval_start.elapsed();
1345
1346        info!(
1347            "{} (built-in) evaluation: total={:.2}μs, context={:.2}μs, eval={:.2}μs",
1348            name_upper,
1349            func_start.elapsed().as_micros(),
1350            context_time.as_micros(),
1351            eval_time.as_micros()
1352        );
1353
1354        result
1355    }
1356
1357    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1358    fn evaluate_method_call(
1359        &mut self,
1360        object: &str,
1361        method: &str,
1362        args: &[SqlExpression],
1363        row_index: usize,
1364    ) -> Result<DataValue> {
1365        // Get column value
1366        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1367            let suggestion = self.find_similar_column(object);
1368            match suggestion {
1369                Some(similar) => {
1370                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1371                }
1372                None => anyhow!("Column '{}' not found", object),
1373            }
1374        })?;
1375
1376        let cell_value = self.table.get_value(row_index, col_index).cloned();
1377
1378        self.evaluate_method_on_value(
1379            &cell_value.unwrap_or(DataValue::Null),
1380            method,
1381            args,
1382            row_index,
1383        )
1384    }
1385
1386    /// Evaluate a method on a value
1387    fn evaluate_method_on_value(
1388        &mut self,
1389        value: &DataValue,
1390        method: &str,
1391        args: &[SqlExpression],
1392        row_index: usize,
1393    ) -> Result<DataValue> {
1394        // First, try to proxy the method through the function registry
1395        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1396
1397        // Map method names to function names (case-insensitive matching)
1398        let function_name = match method.to_lowercase().as_str() {
1399            "trim" => "TRIM",
1400            "trimstart" | "trimbegin" => "TRIMSTART",
1401            "trimend" => "TRIMEND",
1402            "length" | "len" => "LENGTH",
1403            "contains" => "CONTAINS",
1404            "startswith" => "STARTSWITH",
1405            "endswith" => "ENDSWITH",
1406            "indexof" => "INDEXOF",
1407            _ => method, // Try the method name as-is
1408        };
1409
1410        // Check if we have this function in the registry
1411        if self.function_registry.get(function_name).is_some() {
1412            debug!(
1413                "Proxying method '{}' through function registry as '{}'",
1414                method, function_name
1415            );
1416
1417            // Prepare arguments: receiver is the first argument, followed by method args
1418            let mut func_args = vec![value.clone()];
1419
1420            // Evaluate method arguments and add them
1421            for arg in args {
1422                func_args.push(self.evaluate(arg, row_index)?);
1423            }
1424
1425            // Get the function and call it
1426            let func = self.function_registry.get(function_name).unwrap();
1427            return func.evaluate(&func_args);
1428        }
1429
1430        // If not in registry, the method is not supported
1431        // All methods should be registered in the function registry
1432        Err(anyhow!(
1433            "Method '{}' not found. It should be registered in the function registry.",
1434            method
1435        ))
1436    }
1437
1438    /// Evaluate a CASE expression
1439    fn evaluate_case_expression(
1440        &mut self,
1441        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1442        else_branch: &Option<Box<SqlExpression>>,
1443        row_index: usize,
1444    ) -> Result<DataValue> {
1445        debug!(
1446            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1447            row_index
1448        );
1449
1450        // Evaluate each WHEN condition in order
1451        for branch in when_branches {
1452            // Evaluate the condition as a boolean
1453            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1454
1455            if condition_result {
1456                debug!("CASE: WHEN condition matched, evaluating result expression");
1457                return self.evaluate(&branch.result, row_index);
1458            }
1459        }
1460
1461        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1462        if let Some(else_expr) = else_branch {
1463            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1464            self.evaluate(else_expr, row_index)
1465        } else {
1466            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1467            Ok(DataValue::Null)
1468        }
1469    }
1470
1471    /// Evaluate a simple CASE expression
1472    fn evaluate_simple_case_expression(
1473        &mut self,
1474        expr: &Box<SqlExpression>,
1475        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1476        else_branch: &Option<Box<SqlExpression>>,
1477        row_index: usize,
1478    ) -> Result<DataValue> {
1479        debug!(
1480            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1481            row_index
1482        );
1483
1484        // Evaluate the main expression once
1485        let case_value = self.evaluate(expr, row_index)?;
1486        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1487
1488        // Compare against each WHEN value in order
1489        for branch in when_branches {
1490            // Evaluate the WHEN value
1491            let when_value = self.evaluate(&branch.value, row_index)?;
1492
1493            // Check for equality
1494            if self.values_equal(&case_value, &when_value)? {
1495                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1496                return self.evaluate(&branch.result, row_index);
1497            }
1498        }
1499
1500        // If no WHEN value matched, evaluate ELSE clause (or return NULL)
1501        if let Some(else_expr) = else_branch {
1502            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1503            self.evaluate(else_expr, row_index)
1504        } else {
1505            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1506            Ok(DataValue::Null)
1507        }
1508    }
1509
1510    /// Check if two DataValues are equal
1511    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1512        match (left, right) {
1513            (DataValue::Null, DataValue::Null) => Ok(true),
1514            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1515            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1516            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1517            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1518            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1519            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1520            // Type coercion for numeric comparisons
1521            (DataValue::Integer(a), DataValue::Float(b)) => {
1522                Ok((*a as f64 - b).abs() < f64::EPSILON)
1523            }
1524            (DataValue::Float(a), DataValue::Integer(b)) => {
1525                Ok((a - *b as f64).abs() < f64::EPSILON)
1526            }
1527            _ => Ok(false),
1528        }
1529    }
1530
1531    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1532    fn evaluate_condition_as_bool(
1533        &mut self,
1534        expr: &SqlExpression,
1535        row_index: usize,
1536    ) -> Result<bool> {
1537        let value = self.evaluate(expr, row_index)?;
1538
1539        match value {
1540            DataValue::Boolean(b) => Ok(b),
1541            DataValue::Integer(i) => Ok(i != 0),
1542            DataValue::Float(f) => Ok(f != 0.0),
1543            DataValue::Null => Ok(false),
1544            DataValue::String(s) => Ok(!s.is_empty()),
1545            DataValue::InternedString(s) => Ok(!s.is_empty()),
1546            _ => Ok(true), // Other types are considered truthy
1547        }
1548    }
1549
1550    /// Evaluate a DATETIME constructor expression
1551    fn evaluate_datetime_constructor(
1552        &self,
1553        year: i32,
1554        month: u32,
1555        day: u32,
1556        hour: Option<u32>,
1557        minute: Option<u32>,
1558        second: Option<u32>,
1559    ) -> Result<DataValue> {
1560        use chrono::{NaiveDate, TimeZone, Utc};
1561
1562        // Create a NaiveDate
1563        let date = NaiveDate::from_ymd_opt(year, month, day)
1564            .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1565
1566        // Create datetime with provided time components or defaults
1567        let hour = hour.unwrap_or(0);
1568        let minute = minute.unwrap_or(0);
1569        let second = second.unwrap_or(0);
1570
1571        let naive_datetime = date
1572            .and_hms_opt(hour, minute, second)
1573            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1574
1575        // Convert to UTC DateTime
1576        let datetime = Utc.from_utc_datetime(&naive_datetime);
1577
1578        // Format as string with milliseconds
1579        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1580        Ok(DataValue::String(datetime_str))
1581    }
1582
1583    /// Evaluate a DATETIME.TODAY constructor expression
1584    fn evaluate_datetime_today(
1585        &self,
1586        hour: Option<u32>,
1587        minute: Option<u32>,
1588        second: Option<u32>,
1589    ) -> Result<DataValue> {
1590        use chrono::{TimeZone, Utc};
1591
1592        // Get today's date in UTC
1593        let today = Utc::now().date_naive();
1594
1595        // Create datetime with provided time components or defaults
1596        let hour = hour.unwrap_or(0);
1597        let minute = minute.unwrap_or(0);
1598        let second = second.unwrap_or(0);
1599
1600        let naive_datetime = today
1601            .and_hms_opt(hour, minute, second)
1602            .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1603
1604        // Convert to UTC DateTime
1605        let datetime = Utc.from_utc_datetime(&naive_datetime);
1606
1607        // Format as string with milliseconds
1608        let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1609        Ok(DataValue::String(datetime_str))
1610    }
1611}
1612
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a single-row table: a = Integer(10), b = Float(2.5), c = Integer(4).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// Expression referencing an unquoted column by name.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Numeric literal expression from its source text.
    fn number(literal: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(literal.to_string())
    }

    /// `a BETWEEN lower AND upper` over the test table's column `a` (value 10).
    fn a_between(lower: &str, upper: &str) -> SqlExpression {
        SqlExpression::Between {
            expr: Box::new(column("a")),
            lower: Box::new(number(lower)),
            upper: Box::new(number(upper)),
        }
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.evaluate(&column("a"), 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_between_column_in_range() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // column 'a' is 10 — 5 <= 10 <= 20 is true
        assert_eq!(
            evaluator.evaluate(&a_between("5", "20"), 0).unwrap(),
            DataValue::Boolean(true)
        );
    }

    #[test]
    fn test_evaluate_between_column_out_of_range() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // column 'a' is 10 — 11 <= 10 <= 20 is false
        assert_eq!(
            evaluator.evaluate(&a_between("11", "20"), 0).unwrap(),
            DataValue::Boolean(false)
        );
    }

    #[test]
    fn test_evaluate_between_endpoints_inclusive() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // column 'a' is 10 — 10 <= 10 <= 10 is true (both endpoints inclusive)
        assert_eq!(
            evaluator.evaluate(&a_between("10", "10"), 0).unwrap(),
            DataValue::Boolean(true)
        );
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.evaluate(&number("42"), 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Deliberately not 3.14: that literal trips clippy::approx_constant
        // (deny-by-default) because it looks like a truncated PI.
        let result = evaluator.evaluate(&number("2.75"), 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "*".to_string(),
            right: Box::new(column("b")),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}