sql_cli/data/
arithmetic_evaluator.rs

1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::parser::ast::WindowSpec;
8use crate::sql::recursive_parser::SqlExpression;
9use crate::sql::window_context::WindowContext;
10use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tracing::debug;
15
16/// Evaluates SQL expressions to compute `DataValues` (for SELECT clauses)
17/// This is different from `RecursiveWhereEvaluator` which returns boolean
18pub struct ArithmeticEvaluator<'a> {
19    table: &'a DataTable,
20    date_notation: String,
21    function_registry: Arc<FunctionRegistry>,
22    aggregate_registry: Arc<AggregateRegistry>,
23    window_function_registry: Arc<WindowFunctionRegistry>,
24    visible_rows: Option<Vec<usize>>, // For aggregate functions on filtered views
25    window_contexts: HashMap<String, Arc<WindowContext>>, // Cache window contexts by spec
26    table_aliases: HashMap<String, String>, // Map alias -> table name for qualified columns
27}
28
29impl<'a> ArithmeticEvaluator<'a> {
30    #[must_use]
31    pub fn new(table: &'a DataTable) -> Self {
32        Self {
33            table,
34            date_notation: get_date_notation(),
35            function_registry: Arc::new(FunctionRegistry::new()),
36            aggregate_registry: Arc::new(AggregateRegistry::new()),
37            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
38            visible_rows: None,
39            window_contexts: HashMap::new(),
40            table_aliases: HashMap::new(),
41        }
42    }
43
44    #[must_use]
45    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
46        Self {
47            table,
48            date_notation,
49            function_registry: Arc::new(FunctionRegistry::new()),
50            aggregate_registry: Arc::new(AggregateRegistry::new()),
51            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
52            visible_rows: None,
53            window_contexts: HashMap::new(),
54            table_aliases: HashMap::new(),
55        }
56    }
57
58    /// Set visible rows for aggregate functions (for filtered views)
59    #[must_use]
60    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
61        self.visible_rows = Some(rows);
62        self
63    }
64
65    /// Set table aliases for qualified column resolution
66    #[must_use]
67    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
68        self.table_aliases = aliases;
69        self
70    }
71
72    #[must_use]
73    pub fn with_date_notation_and_registry(
74        table: &'a DataTable,
75        date_notation: String,
76        function_registry: Arc<FunctionRegistry>,
77    ) -> Self {
78        Self {
79            table,
80            date_notation,
81            function_registry,
82            aggregate_registry: Arc::new(AggregateRegistry::new()),
83            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
84            visible_rows: None,
85            window_contexts: HashMap::new(),
86            table_aliases: HashMap::new(),
87        }
88    }
89
90    /// Find a column name similar to the given name using edit distance
91    fn find_similar_column(&self, name: &str) -> Option<String> {
92        let columns = self.table.column_names();
93        let mut best_match: Option<(String, usize)> = None;
94
95        for col in columns {
96            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
97            // Only suggest if distance is small (likely a typo)
98            // Allow up to 3 edits for longer names
99            let max_distance = if name.len() > 10 { 3 } else { 2 };
100            if distance <= max_distance {
101                match &best_match {
102                    None => best_match = Some((col, distance)),
103                    Some((_, best_dist)) if distance < *best_dist => {
104                        best_match = Some((col, distance));
105                    }
106                    _ => {}
107                }
108            }
109        }
110
111        best_match.map(|(name, _)| name)
112    }
113
114    /// Calculate Levenshtein edit distance between two strings
115    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
116        // Use the shared implementation from string_methods
117        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
118    }
119
120    /// Evaluate an SQL expression to produce a `DataValue`
121    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
122        debug!(
123            "ArithmeticEvaluator: evaluating {:?} for row {}",
124            expr, row_index
125        );
126
127        match expr {
128            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
129            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
130            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
131            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
132            SqlExpression::Null => Ok(DataValue::Null),
133            SqlExpression::BinaryOp { left, op, right } => {
134                self.evaluate_binary_op(left, op, right, row_index)
135            }
136            SqlExpression::FunctionCall {
137                name,
138                args,
139                distinct,
140            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
141            SqlExpression::WindowFunction {
142                name,
143                args,
144                window_spec,
145            } => self.evaluate_window_function(name, args, window_spec, row_index),
146            SqlExpression::MethodCall {
147                object,
148                method,
149                args,
150            } => self.evaluate_method_call(object, method, args, row_index),
151            SqlExpression::ChainedMethodCall { base, method, args } => {
152                // Evaluate the base expression first, then apply the method
153                let base_value = self.evaluate(base, row_index)?;
154                self.evaluate_method_on_value(&base_value, method, args, row_index)
155            }
156            SqlExpression::CaseExpression {
157                when_branches,
158                else_branch,
159            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
160            _ => Err(anyhow!(
161                "Unsupported expression type for arithmetic evaluation: {:?}",
162                expr
163            )),
164        }
165    }
166
167    /// Evaluate a column reference
168    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
169        // First try to resolve qualified column names (table.column or alias.column)
170        let resolved_column = if column_name.contains('.') {
171            // Split on last dot to handle cases like "schema.table.column"
172            if let Some(dot_pos) = column_name.rfind('.') {
173                let _table_or_alias = &column_name[..dot_pos];
174                let col_name = &column_name[dot_pos + 1..];
175
176                // For now, just use the column name part
177                // In the future, we could validate the table/alias part
178                debug!(
179                    "Resolving qualified column: {} -> {}",
180                    column_name, col_name
181                );
182                col_name.to_string()
183            } else {
184                column_name.to_string()
185            }
186        } else {
187            column_name.to_string()
188        };
189
190        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
191            idx
192        } else if resolved_column != column_name {
193            // If not found, try the original name
194            if let Some(idx) = self.table.get_column_index(column_name) {
195                idx
196            } else {
197                let suggestion = self.find_similar_column(&resolved_column);
198                return Err(match suggestion {
199                    Some(similar) => anyhow!(
200                        "Column '{}' not found. Did you mean '{}'?",
201                        column_name,
202                        similar
203                    ),
204                    None => anyhow!("Column '{}' not found", column_name),
205                });
206            }
207        } else {
208            let suggestion = self.find_similar_column(&resolved_column);
209            return Err(match suggestion {
210                Some(similar) => anyhow!(
211                    "Column '{}' not found. Did you mean '{}'?",
212                    column_name,
213                    similar
214                ),
215                None => anyhow!("Column '{}' not found", column_name),
216            });
217        };
218
219        if row_index >= self.table.row_count() {
220            return Err(anyhow!("Row index {} out of bounds", row_index));
221        }
222
223        let row = self
224            .table
225            .get_row(row_index)
226            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
227
228        let value = row
229            .get(col_index)
230            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
231
232        Ok(value.clone())
233    }
234
235    /// Evaluate a number literal (handles both integers and floats)
236    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
237        // Try to parse as integer first
238        if let Ok(int_val) = number_str.parse::<i64>() {
239            return Ok(DataValue::Integer(int_val));
240        }
241
242        // If that fails, try as float
243        if let Ok(float_val) = number_str.parse::<f64>() {
244            return Ok(DataValue::Float(float_val));
245        }
246
247        Err(anyhow!("Invalid number literal: {}", number_str))
248    }
249
250    /// Evaluate a binary operation (arithmetic)
251    fn evaluate_binary_op(
252        &mut self,
253        left: &SqlExpression,
254        op: &str,
255        right: &SqlExpression,
256        row_index: usize,
257    ) -> Result<DataValue> {
258        let left_val = self.evaluate(left, row_index)?;
259        let right_val = self.evaluate(right, row_index)?;
260
261        debug!(
262            "ArithmeticEvaluator: {} {} {}",
263            self.format_value(&left_val),
264            op,
265            self.format_value(&right_val)
266        );
267
268        match op {
269            "+" => self.add_values(&left_val, &right_val),
270            "-" => self.subtract_values(&left_val, &right_val),
271            "*" => self.multiply_values(&left_val, &right_val),
272            "/" => self.divide_values(&left_val, &right_val),
273            "%" => {
274                // Modulo operator - call MOD function
275                let args = vec![left.clone(), right.clone()];
276                self.evaluate_function("MOD", &args, row_index)
277            }
278            // Comparison operators (return boolean results)
279            // Use centralized comparison logic for consistency
280            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
281                let result = compare_with_op(&left_val, &right_val, op, false);
282                Ok(DataValue::Boolean(result))
283            }
284            // IS NULL / IS NOT NULL operators
285            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
286            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
287            // Logical operators
288            "AND" => {
289                let left_bool = self.to_bool(&left_val)?;
290                let right_bool = self.to_bool(&right_val)?;
291                Ok(DataValue::Boolean(left_bool && right_bool))
292            }
293            "OR" => {
294                let left_bool = self.to_bool(&left_val)?;
295                let right_bool = self.to_bool(&right_val)?;
296                Ok(DataValue::Boolean(left_bool || right_bool))
297            }
298            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
299        }
300    }
301
302    /// Add two `DataValues` with type coercion
303    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
304        // NULL handling - any operation with NULL returns NULL
305        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
306            return Ok(DataValue::Null);
307        }
308
309        match (left, right) {
310            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
311            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
312            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
313            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
314            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
315        }
316    }
317
318    /// Subtract two `DataValues` with type coercion
319    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
320        // NULL handling - any operation with NULL returns NULL
321        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
322            return Ok(DataValue::Null);
323        }
324
325        match (left, right) {
326            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
327            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
328            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
329            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
330            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
331        }
332    }
333
334    /// Multiply two `DataValues` with type coercion
335    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
336        // NULL handling - any operation with NULL returns NULL
337        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
338            return Ok(DataValue::Null);
339        }
340
341        match (left, right) {
342            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
343            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
344            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
345            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
346            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
347        }
348    }
349
350    /// Divide two `DataValues` with type coercion
351    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
352        // NULL handling - any operation with NULL returns NULL
353        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
354            return Ok(DataValue::Null);
355        }
356
357        // Check for division by zero first
358        let is_zero = match right {
359            DataValue::Integer(0) => true,
360            DataValue::Float(f) if *f == 0.0 => true, // Only check for exact zero, not epsilon
361            _ => false,
362        };
363
364        if is_zero {
365            return Err(anyhow!("Division by zero"));
366        }
367
368        match (left, right) {
369            (DataValue::Integer(a), DataValue::Integer(b)) => {
370                // Integer division - if result is exact, keep as int, otherwise promote to float
371                if a % b == 0 {
372                    Ok(DataValue::Integer(a / b))
373                } else {
374                    Ok(DataValue::Float(*a as f64 / *b as f64))
375                }
376            }
377            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
378            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
379            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
380            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
381        }
382    }
383
384    /// Format a `DataValue` for debug output
385    fn format_value(&self, value: &DataValue) -> String {
386        match value {
387            DataValue::Integer(i) => i.to_string(),
388            DataValue::Float(f) => f.to_string(),
389            DataValue::String(s) => format!("'{s}'"),
390            _ => format!("{value:?}"),
391        }
392    }
393
394    /// Convert a DataValue to boolean for logical operations
395    fn to_bool(&self, value: &DataValue) -> Result<bool> {
396        match value {
397            DataValue::Boolean(b) => Ok(*b),
398            DataValue::Integer(i) => Ok(*i != 0),
399            DataValue::Float(f) => Ok(*f != 0.0),
400            DataValue::Null => Ok(false),
401            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
402        }
403    }
404
405    /// Evaluate a function call
406    fn evaluate_function_with_distinct(
407        &mut self,
408        name: &str,
409        args: &[SqlExpression],
410        distinct: bool,
411        row_index: usize,
412    ) -> Result<DataValue> {
413        // If DISTINCT is specified, handle it specially for aggregate functions
414        if distinct {
415            let name_upper = name.to_uppercase();
416
417            // Check if it's an aggregate function in the registry
418            if self.aggregate_registry.is_aggregate(&name_upper) {
419                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
420            } else {
421                return Err(anyhow!(
422                    "DISTINCT can only be used with aggregate functions"
423                ));
424            }
425        }
426
427        // Otherwise, use the regular evaluation
428        self.evaluate_function(name, args, row_index)
429    }
430
431    fn evaluate_aggregate_with_distinct(
432        &mut self,
433        name: &str,
434        args: &[SqlExpression],
435        _row_index: usize,
436    ) -> Result<DataValue> {
437        let name_upper = name.to_uppercase();
438
439        // Check aggregate registry (DISTINCT handling)
440        if self.aggregate_registry.get(&name_upper).is_some() {
441            // Determine which rows to process first
442            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
443                visible.clone()
444            } else {
445                (0..self.table.rows.len()).collect()
446            };
447
448            // Special handling for STRING_AGG with separator parameter
449            if name_upper == "STRING_AGG" && args.len() >= 2 {
450                // STRING_AGG(DISTINCT column, separator)
451                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
452                    // Evaluate the separator (second argument) once
453                    if args.len() >= 2 {
454                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
455                        match separator {
456                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
457                            DataValue::InternedString(s) => {
458                                crate::sql::aggregates::StringAggState::new(&s)
459                            }
460                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
461                        }
462                    } else {
463                        crate::sql::aggregates::StringAggState::new(",")
464                    },
465                );
466
467                // Evaluate the first argument (column) for each row and accumulate
468                // Handle DISTINCT - use a HashSet to track seen values
469                let mut seen_values = HashSet::new();
470
471                for &row_idx in &rows_to_process {
472                    let value = self.evaluate(&args[0], row_idx)?;
473
474                    // Skip if we've seen this value
475                    if !seen_values.insert(value.clone()) {
476                        continue; // Skip duplicate values
477                    }
478
479                    // Now get the aggregate function and accumulate
480                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
481                    agg_func.accumulate(&mut state, &value)?;
482                }
483
484                // Finalize the aggregate
485                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
486                return Ok(agg_func.finalize(state));
487            }
488
489            // For other aggregates with DISTINCT
490            // Evaluate the argument expression for each row
491            let mut vals = Vec::new();
492            for &row_idx in &rows_to_process {
493                if !args.is_empty() {
494                    let value = self.evaluate(&args[0], row_idx)?;
495                    vals.push(value);
496                }
497            }
498
499            // Deduplicate values for DISTINCT
500            let mut seen = HashSet::new();
501            let mut unique_values = Vec::new();
502            for value in vals {
503                if seen.insert(value.clone()) {
504                    unique_values.push(value);
505                }
506            }
507
508            // Now get the aggregate function and process
509            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
510            let mut state = agg_func.init();
511
512            // Use unique values
513            for value in &unique_values {
514                agg_func.accumulate(&mut state, value)?;
515            }
516
517            return Ok(agg_func.finalize(state));
518        }
519
520        Err(anyhow!("Unknown aggregate function: {}", name))
521    }
522
523    fn evaluate_function(
524        &mut self,
525        name: &str,
526        args: &[SqlExpression],
527        row_index: usize,
528    ) -> Result<DataValue> {
529        // Check if this is an aggregate function
530        let name_upper = name.to_uppercase();
531
532        // Handle COUNT(*) special case
533        if name_upper == "COUNT" && args.len() == 1 {
534            match &args[0] {
535                SqlExpression::Column(col) if col == "*" => {
536                    // COUNT(*) - count all rows (or visible rows if filtered)
537                    let count = if let Some(ref visible) = self.visible_rows {
538                        visible.len() as i64
539                    } else {
540                        self.table.rows.len() as i64
541                    };
542                    return Ok(DataValue::Integer(count));
543                }
544                SqlExpression::StringLiteral(s) if s == "*" => {
545                    // COUNT(*) parsed as StringLiteral
546                    let count = if let Some(ref visible) = self.visible_rows {
547                        visible.len() as i64
548                    } else {
549                        self.table.rows.len() as i64
550                    };
551                    return Ok(DataValue::Integer(count));
552                }
553                _ => {
554                    // COUNT(column) - will be handled below
555                }
556            }
557        }
558
559        // Check aggregate registry
560        if self.aggregate_registry.get(&name_upper).is_some() {
561            // Determine which rows to process first
562            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
563                visible.clone()
564            } else {
565                (0..self.table.rows.len()).collect()
566            };
567
568            // Special handling for STRING_AGG with separator parameter
569            if name_upper == "STRING_AGG" && args.len() >= 2 {
570                // STRING_AGG(column, separator) - without DISTINCT (handled separately)
571                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
572                    // Evaluate the separator (second argument) once
573                    if args.len() >= 2 {
574                        let separator = self.evaluate(&args[1], 0)?; // Separator doesn't depend on row
575                        match separator {
576                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
577                            DataValue::InternedString(s) => {
578                                crate::sql::aggregates::StringAggState::new(&s)
579                            }
580                            _ => crate::sql::aggregates::StringAggState::new(","), // Default separator
581                        }
582                    } else {
583                        crate::sql::aggregates::StringAggState::new(",")
584                    },
585                );
586
587                // Evaluate the first argument (column) for each row and accumulate
588                for &row_idx in &rows_to_process {
589                    let value = self.evaluate(&args[0], row_idx)?;
590                    // Now get the aggregate function and accumulate
591                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
592                    agg_func.accumulate(&mut state, &value)?;
593                }
594
595                // Finalize the aggregate
596                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
597                return Ok(agg_func.finalize(state));
598            }
599
600            // Evaluate arguments first if needed (to avoid borrow issues)
601            let values = if !args.is_empty()
602                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
603            {
604                // Evaluate the argument expression for each row
605                let mut vals = Vec::new();
606                for &row_idx in &rows_to_process {
607                    let value = self.evaluate(&args[0], row_idx)?;
608                    vals.push(value);
609                }
610                Some(vals)
611            } else {
612                None
613            };
614
615            // Now get the aggregate function and process
616            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
617            let mut state = agg_func.init();
618
619            if let Some(values) = values {
620                // Use evaluated values (DISTINCT is handled in evaluate_aggregate_with_distinct)
621                for value in &values {
622                    agg_func.accumulate(&mut state, value)?;
623                }
624            } else {
625                // COUNT(*) case
626                for _ in &rows_to_process {
627                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
628                }
629            }
630
631            return Ok(agg_func.finalize(state));
632        }
633
634        // First check if this function exists in the registry
635        if self.function_registry.get(name).is_some() {
636            // Evaluate all arguments first to avoid borrow issues
637            let mut evaluated_args = Vec::new();
638            for arg in args {
639                evaluated_args.push(self.evaluate(arg, row_index)?);
640            }
641
642            // Get the function and call it
643            let func = self.function_registry.get(name).unwrap();
644            return func.evaluate(&evaluated_args);
645        }
646
647        // If not in registry, return error for unknown function
648        Err(anyhow!("Unknown function: {}", name))
649    }
650
651    /// Get or create a WindowContext for the given specification
652    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
653        // Create a key for caching based on the spec
654        let key = format!("{:?}", spec);
655
656        if let Some(context) = self.window_contexts.get(&key) {
657            return Ok(Arc::clone(context));
658        }
659
660        // Create a DataView from the table (with visible rows if filtered)
661        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
662            // Create a filtered view
663            let view = DataView::new(Arc::new(self.table.clone()));
664            // Apply filtering based on visible rows
665            // Note: This is a simplified approach - in production we'd need proper filtering
666            view
667        } else {
668            DataView::new(Arc::new(self.table.clone()))
669        };
670
671        // Create the WindowContext with the full spec (including frame)
672        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
673
674        let context = Arc::new(context);
675        self.window_contexts.insert(key, Arc::clone(&context));
676        Ok(context)
677    }
678
679    /// Evaluate a window function
680    fn evaluate_window_function(
681        &mut self,
682        name: &str,
683        args: &[SqlExpression],
684        spec: &WindowSpec,
685        row_index: usize,
686    ) -> Result<DataValue> {
687        let name_upper = name.to_uppercase();
688
689        // First check if this is a syntactic sugar function in the registry
690        debug!("Looking for window function {} in registry", name_upper);
691        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
692            debug!("Found window function {} in registry", name_upper);
693
694            // Dereference to get the actual window function
695            let window_fn = window_fn_arc.as_ref();
696
697            // Validate arguments
698            window_fn.validate_args(args)?;
699
700            // Transform the window spec based on the function's requirements
701            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
702
703            // Get or create the window context with the transformed spec
704            let context = self.get_or_create_window_context(&transformed_spec)?;
705
706            // Create an expression evaluator adapter
707            struct EvaluatorAdapter<'a, 'b> {
708                evaluator: &'a mut ArithmeticEvaluator<'b>,
709                row_index: usize,
710            }
711
712            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
713                fn evaluate(
714                    &mut self,
715                    expr: &SqlExpression,
716                    _row_index: usize,
717                ) -> Result<DataValue> {
718                    self.evaluator.evaluate(expr, self.row_index)
719                }
720            }
721
722            let mut adapter = EvaluatorAdapter {
723                evaluator: self,
724                row_index,
725            };
726
727            // Call the window function's compute method
728            return window_fn.compute(&context, row_index, args, &mut adapter);
729        }
730
731        // Fall back to built-in window functions
732        let context = self.get_or_create_window_context(spec)?;
733
734        match name_upper.as_str() {
735            "LAG" => {
736                // LAG(column, offset, default)
737                if args.is_empty() {
738                    return Err(anyhow!("LAG requires at least 1 argument"));
739                }
740
741                // Get column name
742                let column = match &args[0] {
743                    SqlExpression::Column(col) => col.clone(),
744                    _ => return Err(anyhow!("LAG first argument must be a column")),
745                };
746
747                // Get offset (default 1)
748                let offset = if args.len() > 1 {
749                    match self.evaluate(&args[1], row_index)? {
750                        DataValue::Integer(i) => i as i32,
751                        _ => return Err(anyhow!("LAG offset must be an integer")),
752                    }
753                } else {
754                    1
755                };
756
757                // Get value at offset
758                Ok(context
759                    .get_offset_value(row_index, -offset, &column)
760                    .unwrap_or(DataValue::Null))
761            }
762            "LEAD" => {
763                // LEAD(column, offset, default)
764                if args.is_empty() {
765                    return Err(anyhow!("LEAD requires at least 1 argument"));
766                }
767
768                // Get column name
769                let column = match &args[0] {
770                    SqlExpression::Column(col) => col.clone(),
771                    _ => return Err(anyhow!("LEAD first argument must be a column")),
772                };
773
774                // Get offset (default 1)
775                let offset = if args.len() > 1 {
776                    match self.evaluate(&args[1], row_index)? {
777                        DataValue::Integer(i) => i as i32,
778                        _ => return Err(anyhow!("LEAD offset must be an integer")),
779                    }
780                } else {
781                    1
782                };
783
784                // Get value at offset
785                Ok(context
786                    .get_offset_value(row_index, offset, &column)
787                    .unwrap_or(DataValue::Null))
788            }
789            "ROW_NUMBER" => {
790                // ROW_NUMBER() - no arguments
791                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
792            }
793            "FIRST_VALUE" => {
794                // FIRST_VALUE(column) OVER (... ROWS ...)
795                if args.is_empty() {
796                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
797                }
798
799                let column = match &args[0] {
800                    SqlExpression::Column(col) => col.clone(),
801                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
802                };
803
804                // Use frame-aware version if frame is specified
805                if context.has_frame() {
806                    Ok(context
807                        .get_frame_first_value(row_index, &column)
808                        .unwrap_or(DataValue::Null))
809                } else {
810                    Ok(context
811                        .get_first_value(row_index, &column)
812                        .unwrap_or(DataValue::Null))
813                }
814            }
815            "LAST_VALUE" => {
816                // LAST_VALUE(column) OVER (... ROWS ...)
817                if args.is_empty() {
818                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
819                }
820
821                let column = match &args[0] {
822                    SqlExpression::Column(col) => col.clone(),
823                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
824                };
825
826                // Use frame-aware version if frame is specified
827                if context.has_frame() {
828                    Ok(context
829                        .get_frame_last_value(row_index, &column)
830                        .unwrap_or(DataValue::Null))
831                } else {
832                    Ok(context
833                        .get_last_value(row_index, &column)
834                        .unwrap_or(DataValue::Null))
835                }
836            }
837            "SUM" => {
838                // SUM(column) OVER (PARTITION BY ... ROWS n PRECEDING)
839                if args.is_empty() {
840                    return Err(anyhow!("SUM requires 1 argument"));
841                }
842
843                let column = match &args[0] {
844                    SqlExpression::Column(col) => col.clone(),
845                    _ => return Err(anyhow!("SUM argument must be a column")),
846                };
847
848                // Use frame-aware sum if frame is specified, otherwise use partition sum
849                if context.has_frame() {
850                    Ok(context
851                        .get_frame_sum(row_index, &column)
852                        .unwrap_or(DataValue::Null))
853                } else {
854                    Ok(context
855                        .get_partition_sum(row_index, &column)
856                        .unwrap_or(DataValue::Null))
857                }
858            }
859            "AVG" => {
860                // AVG(column) OVER (PARTITION BY ... ROWS n PRECEDING)
861                if args.is_empty() {
862                    return Err(anyhow!("AVG requires 1 argument"));
863                }
864
865                let column = match &args[0] {
866                    SqlExpression::Column(col) => col.clone(),
867                    _ => return Err(anyhow!("AVG argument must be a column")),
868                };
869
870                Ok(context
871                    .get_frame_avg(row_index, &column)
872                    .unwrap_or(DataValue::Null))
873            }
874            "STDDEV" | "STDEV" => {
875                // STDDEV(column) OVER (PARTITION BY ... ROWS n PRECEDING)
876                if args.is_empty() {
877                    return Err(anyhow!("STDDEV requires 1 argument"));
878                }
879
880                let column = match &args[0] {
881                    SqlExpression::Column(col) => col.clone(),
882                    _ => return Err(anyhow!("STDDEV argument must be a column")),
883                };
884
885                Ok(context
886                    .get_frame_stddev(row_index, &column)
887                    .unwrap_or(DataValue::Null))
888            }
889            "VARIANCE" | "VAR" => {
890                // VARIANCE(column) OVER (PARTITION BY ... ROWS n PRECEDING)
891                if args.is_empty() {
892                    return Err(anyhow!("VARIANCE requires 1 argument"));
893                }
894
895                let column = match &args[0] {
896                    SqlExpression::Column(col) => col.clone(),
897                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
898                };
899
900                Ok(context
901                    .get_frame_variance(row_index, &column)
902                    .unwrap_or(DataValue::Null))
903            }
904            "MIN" => {
905                // MIN(column) OVER (PARTITION BY ... ROWS n PRECEDING)
906                if args.is_empty() {
907                    return Err(anyhow!("MIN requires 1 argument"));
908                }
909
910                let column = match &args[0] {
911                    SqlExpression::Column(col) => col.clone(),
912                    _ => return Err(anyhow!("MIN argument must be a column")),
913                };
914
915                let frame_rows = context.get_frame_rows(row_index);
916                if frame_rows.is_empty() {
917                    return Ok(DataValue::Null);
918                }
919
920                let source_table = context.source();
921                let col_idx = source_table
922                    .get_column_index(&column)
923                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
924
925                let mut min_value: Option<DataValue> = None;
926                for &row_idx in &frame_rows {
927                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
928                        if !matches!(value, DataValue::Null) {
929                            match &min_value {
930                                None => min_value = Some(value.clone()),
931                                Some(current_min) => {
932                                    if value < current_min {
933                                        min_value = Some(value.clone());
934                                    }
935                                }
936                            }
937                        }
938                    }
939                }
940
941                Ok(min_value.unwrap_or(DataValue::Null))
942            }
943            "MAX" => {
944                // MAX(column) OVER (PARTITION BY ... ROWS n PRECEDING)
945                if args.is_empty() {
946                    return Err(anyhow!("MAX requires 1 argument"));
947                }
948
949                let column = match &args[0] {
950                    SqlExpression::Column(col) => col.clone(),
951                    _ => return Err(anyhow!("MAX argument must be a column")),
952                };
953
954                let frame_rows = context.get_frame_rows(row_index);
955                if frame_rows.is_empty() {
956                    return Ok(DataValue::Null);
957                }
958
959                let source_table = context.source();
960                let col_idx = source_table
961                    .get_column_index(&column)
962                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
963
964                let mut max_value: Option<DataValue> = None;
965                for &row_idx in &frame_rows {
966                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
967                        if !matches!(value, DataValue::Null) {
968                            match &max_value {
969                                None => max_value = Some(value.clone()),
970                                Some(current_max) => {
971                                    if value > current_max {
972                                        max_value = Some(value.clone());
973                                    }
974                                }
975                            }
976                        }
977                    }
978                }
979
980                Ok(max_value.unwrap_or(DataValue::Null))
981            }
982            "COUNT" => {
983                // COUNT(*) or COUNT(column) OVER (PARTITION BY ... ROWS n PRECEDING)
984                // Use frame-aware count if frame is specified, otherwise use partition count
985
986                if args.is_empty() {
987                    // COUNT(*) OVER (...)
988                    if context.has_frame() {
989                        Ok(context
990                            .get_frame_count(row_index, None)
991                            .unwrap_or(DataValue::Null))
992                    } else {
993                        Ok(context
994                            .get_partition_count(row_index, None)
995                            .unwrap_or(DataValue::Null))
996                    }
997                } else {
998                    // Check for COUNT(*)
999                    let column = match &args[0] {
1000                        SqlExpression::Column(col) => {
1001                            if col == "*" {
1002                                // COUNT(*) - count all rows
1003                                if context.has_frame() {
1004                                    return Ok(context
1005                                        .get_frame_count(row_index, None)
1006                                        .unwrap_or(DataValue::Null));
1007                                } else {
1008                                    return Ok(context
1009                                        .get_partition_count(row_index, None)
1010                                        .unwrap_or(DataValue::Null));
1011                                }
1012                            }
1013                            col.clone()
1014                        }
1015                        SqlExpression::StringLiteral(s) if s == "*" => {
1016                            // COUNT(*) as StringLiteral
1017                            if context.has_frame() {
1018                                return Ok(context
1019                                    .get_frame_count(row_index, None)
1020                                    .unwrap_or(DataValue::Null));
1021                            } else {
1022                                return Ok(context
1023                                    .get_partition_count(row_index, None)
1024                                    .unwrap_or(DataValue::Null));
1025                            }
1026                        }
1027                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1028                    };
1029
1030                    // COUNT(column) - count non-null values
1031                    if context.has_frame() {
1032                        Ok(context
1033                            .get_frame_count(row_index, Some(&column))
1034                            .unwrap_or(DataValue::Null))
1035                    } else {
1036                        Ok(context
1037                            .get_partition_count(row_index, Some(&column))
1038                            .unwrap_or(DataValue::Null))
1039                    }
1040                }
1041            }
1042            _ => Err(anyhow!("Unknown window function: {}", name)),
1043        }
1044    }
1045
1046    /// Evaluate a method call on a column (e.g., `column.Trim()`)
1047    fn evaluate_method_call(
1048        &mut self,
1049        object: &str,
1050        method: &str,
1051        args: &[SqlExpression],
1052        row_index: usize,
1053    ) -> Result<DataValue> {
1054        // Get column value
1055        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1056            let suggestion = self.find_similar_column(object);
1057            match suggestion {
1058                Some(similar) => {
1059                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1060                }
1061                None => anyhow!("Column '{}' not found", object),
1062            }
1063        })?;
1064
1065        let cell_value = self.table.get_value(row_index, col_index).cloned();
1066
1067        self.evaluate_method_on_value(
1068            &cell_value.unwrap_or(DataValue::Null),
1069            method,
1070            args,
1071            row_index,
1072        )
1073    }
1074
1075    /// Evaluate a method on a value
1076    fn evaluate_method_on_value(
1077        &mut self,
1078        value: &DataValue,
1079        method: &str,
1080        args: &[SqlExpression],
1081        row_index: usize,
1082    ) -> Result<DataValue> {
1083        // First, try to proxy the method through the function registry
1084        // Many string methods have corresponding functions (TRIM, LENGTH, CONTAINS, etc.)
1085
1086        // Map method names to function names (case-insensitive matching)
1087        let function_name = match method.to_lowercase().as_str() {
1088            "trim" => "TRIM",
1089            "trimstart" | "trimbegin" => "TRIMSTART",
1090            "trimend" => "TRIMEND",
1091            "length" | "len" => "LENGTH",
1092            "contains" => "CONTAINS",
1093            "startswith" => "STARTSWITH",
1094            "endswith" => "ENDSWITH",
1095            "indexof" => "INDEXOF",
1096            _ => method, // Try the method name as-is
1097        };
1098
1099        // Check if we have this function in the registry
1100        if self.function_registry.get(function_name).is_some() {
1101            debug!(
1102                "Proxying method '{}' through function registry as '{}'",
1103                method, function_name
1104            );
1105
1106            // Prepare arguments: receiver is the first argument, followed by method args
1107            let mut func_args = vec![value.clone()];
1108
1109            // Evaluate method arguments and add them
1110            for arg in args {
1111                func_args.push(self.evaluate(arg, row_index)?);
1112            }
1113
1114            // Get the function and call it
1115            let func = self.function_registry.get(function_name).unwrap();
1116            return func.evaluate(&func_args);
1117        }
1118
1119        // If not in registry, the method is not supported
1120        // All methods should be registered in the function registry
1121        Err(anyhow!(
1122            "Method '{}' not found. It should be registered in the function registry.",
1123            method
1124        ))
1125    }
1126
1127    /// Evaluate a CASE expression
1128    fn evaluate_case_expression(
1129        &mut self,
1130        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1131        else_branch: &Option<Box<SqlExpression>>,
1132        row_index: usize,
1133    ) -> Result<DataValue> {
1134        debug!(
1135            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1136            row_index
1137        );
1138
1139        // Evaluate each WHEN condition in order
1140        for branch in when_branches {
1141            // Evaluate the condition as a boolean
1142            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1143
1144            if condition_result {
1145                debug!("CASE: WHEN condition matched, evaluating result expression");
1146                return self.evaluate(&branch.result, row_index);
1147            }
1148        }
1149
1150        // If no WHEN condition matched, evaluate ELSE clause (or return NULL)
1151        if let Some(else_expr) = else_branch {
1152            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1153            self.evaluate(else_expr, row_index)
1154        } else {
1155            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1156            Ok(DataValue::Null)
1157        }
1158    }
1159
1160    /// Helper method to evaluate an expression as a boolean (for CASE WHEN conditions)
1161    fn evaluate_condition_as_bool(
1162        &mut self,
1163        expr: &SqlExpression,
1164        row_index: usize,
1165    ) -> Result<bool> {
1166        let value = self.evaluate(expr, row_index)?;
1167
1168        match value {
1169            DataValue::Boolean(b) => Ok(b),
1170            DataValue::Integer(i) => Ok(i != 0),
1171            DataValue::Float(f) => Ok(f != 0.0),
1172            DataValue::Null => Ok(false),
1173            DataValue::String(s) => Ok(!s.is_empty()),
1174            DataValue::InternedString(s) => Ok(!s.is_empty()),
1175            _ => Ok(true), // Other types are considered truthy
1176        }
1177    }
1178}
1179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Build a one-row table with columns a=10 (int), b=2.5 (float), c=4 (int).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 1.5 is exactly representable and does not trip clippy's
        // `approx_constant` lint (3.14 would be flagged as an approximation of PI).
        let expr = SqlExpression::NumberLiteral("1.5".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(1.5));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer * Float
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Non-exact division
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b where a=10, b=2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Non-zero integer column (a=10) is truthy
        let expr = SqlExpression::Column("a".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());

        // Non-zero float column (b=2.5) is truthy
        let expr = SqlExpression::Column("b".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());

        // Integer zero is falsy
        let expr = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());
    }

    #[test]
    fn test_method_call_unknown_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.evaluate_method_call("missing", "trim", &[], 0);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Column 'missing' not found"));
    }

    #[test]
    fn test_unknown_method_is_error() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result =
            evaluator.evaluate_method_on_value(&DataValue::Null, "noSuchMethodXyz", &[], 0);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Method 'noSuchMethodXyz' not found"));
    }
}