1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::parser::ast::WindowSpec;
8use crate::sql::recursive_parser::SqlExpression;
9use crate::sql::window_context::WindowContext;
10use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tracing::debug;
15
16pub struct ArithmeticEvaluator<'a> {
19    table: &'a DataTable,
20    date_notation: String,
21    function_registry: Arc<FunctionRegistry>,
22    aggregate_registry: Arc<AggregateRegistry>,
23    window_function_registry: Arc<WindowFunctionRegistry>,
24    visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
28
29impl<'a> ArithmeticEvaluator<'a> {
30    #[must_use]
31    pub fn new(table: &'a DataTable) -> Self {
32        Self {
33            table,
34            date_notation: get_date_notation(),
35            function_registry: Arc::new(FunctionRegistry::new()),
36            aggregate_registry: Arc::new(AggregateRegistry::new()),
37            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
38            visible_rows: None,
39            window_contexts: HashMap::new(),
40            table_aliases: HashMap::new(),
41        }
42    }
43
44    #[must_use]
45    pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
46        Self {
47            table,
48            date_notation,
49            function_registry: Arc::new(FunctionRegistry::new()),
50            aggregate_registry: Arc::new(AggregateRegistry::new()),
51            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
52            visible_rows: None,
53            window_contexts: HashMap::new(),
54            table_aliases: HashMap::new(),
55        }
56    }
57
58    #[must_use]
60    pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
61        self.visible_rows = Some(rows);
62        self
63    }
64
65    #[must_use]
67    pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
68        self.table_aliases = aliases;
69        self
70    }
71
72    #[must_use]
73    pub fn with_date_notation_and_registry(
74        table: &'a DataTable,
75        date_notation: String,
76        function_registry: Arc<FunctionRegistry>,
77    ) -> Self {
78        Self {
79            table,
80            date_notation,
81            function_registry,
82            aggregate_registry: Arc::new(AggregateRegistry::new()),
83            window_function_registry: Arc::new(WindowFunctionRegistry::new()),
84            visible_rows: None,
85            window_contexts: HashMap::new(),
86            table_aliases: HashMap::new(),
87        }
88    }
89
90    fn find_similar_column(&self, name: &str) -> Option<String> {
92        let columns = self.table.column_names();
93        let mut best_match: Option<(String, usize)> = None;
94
95        for col in columns {
96            let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
97            let max_distance = if name.len() > 10 { 3 } else { 2 };
100            if distance <= max_distance {
101                match &best_match {
102                    None => best_match = Some((col, distance)),
103                    Some((_, best_dist)) if distance < *best_dist => {
104                        best_match = Some((col, distance));
105                    }
106                    _ => {}
107                }
108            }
109        }
110
111        best_match.map(|(name, _)| name)
112    }
113
114    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
116        crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
118    }
119
120    pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
122        debug!(
123            "ArithmeticEvaluator: evaluating {:?} for row {}",
124            expr, row_index
125        );
126
127        match expr {
128            SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
129            SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
130            SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
131            SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
132            SqlExpression::Null => Ok(DataValue::Null),
133            SqlExpression::BinaryOp { left, op, right } => {
134                self.evaluate_binary_op(left, op, right, row_index)
135            }
136            SqlExpression::FunctionCall {
137                name,
138                args,
139                distinct,
140            } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
141            SqlExpression::WindowFunction {
142                name,
143                args,
144                window_spec,
145            } => self.evaluate_window_function(name, args, window_spec, row_index),
146            SqlExpression::MethodCall {
147                object,
148                method,
149                args,
150            } => self.evaluate_method_call(object, method, args, row_index),
151            SqlExpression::ChainedMethodCall { base, method, args } => {
152                let base_value = self.evaluate(base, row_index)?;
154                self.evaluate_method_on_value(&base_value, method, args, row_index)
155            }
156            SqlExpression::CaseExpression {
157                when_branches,
158                else_branch,
159            } => self.evaluate_case_expression(when_branches, else_branch, row_index),
160            SqlExpression::SimpleCaseExpression {
161                expr,
162                when_branches,
163                else_branch,
164            } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
165            _ => Err(anyhow!(
166                "Unsupported expression type for arithmetic evaluation: {:?}",
167                expr
168            )),
169        }
170    }
171
172    fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
174        let resolved_column = if column_name.contains('.') {
176            if let Some(dot_pos) = column_name.rfind('.') {
178                let _table_or_alias = &column_name[..dot_pos];
179                let col_name = &column_name[dot_pos + 1..];
180
181                debug!(
184                    "Resolving qualified column: {} -> {}",
185                    column_name, col_name
186                );
187                col_name.to_string()
188            } else {
189                column_name.to_string()
190            }
191        } else {
192            column_name.to_string()
193        };
194
195        let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
196            idx
197        } else if resolved_column != column_name {
198            if let Some(idx) = self.table.get_column_index(column_name) {
200                idx
201            } else {
202                let suggestion = self.find_similar_column(&resolved_column);
203                return Err(match suggestion {
204                    Some(similar) => anyhow!(
205                        "Column '{}' not found. Did you mean '{}'?",
206                        column_name,
207                        similar
208                    ),
209                    None => anyhow!("Column '{}' not found", column_name),
210                });
211            }
212        } else {
213            let suggestion = self.find_similar_column(&resolved_column);
214            return Err(match suggestion {
215                Some(similar) => anyhow!(
216                    "Column '{}' not found. Did you mean '{}'?",
217                    column_name,
218                    similar
219                ),
220                None => anyhow!("Column '{}' not found", column_name),
221            });
222        };
223
224        if row_index >= self.table.row_count() {
225            return Err(anyhow!("Row index {} out of bounds", row_index));
226        }
227
228        let row = self
229            .table
230            .get_row(row_index)
231            .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
232
233        let value = row
234            .get(col_index)
235            .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
236
237        Ok(value.clone())
238    }
239
240    fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
242        if let Ok(int_val) = number_str.parse::<i64>() {
244            return Ok(DataValue::Integer(int_val));
245        }
246
247        if let Ok(float_val) = number_str.parse::<f64>() {
249            return Ok(DataValue::Float(float_val));
250        }
251
252        Err(anyhow!("Invalid number literal: {}", number_str))
253    }
254
255    fn evaluate_binary_op(
257        &mut self,
258        left: &SqlExpression,
259        op: &str,
260        right: &SqlExpression,
261        row_index: usize,
262    ) -> Result<DataValue> {
263        let left_val = self.evaluate(left, row_index)?;
264        let right_val = self.evaluate(right, row_index)?;
265
266        debug!(
267            "ArithmeticEvaluator: {} {} {}",
268            self.format_value(&left_val),
269            op,
270            self.format_value(&right_val)
271        );
272
273        match op {
274            "+" => self.add_values(&left_val, &right_val),
275            "-" => self.subtract_values(&left_val, &right_val),
276            "*" => self.multiply_values(&left_val, &right_val),
277            "/" => self.divide_values(&left_val, &right_val),
278            "%" => {
279                let args = vec![left.clone(), right.clone()];
281                self.evaluate_function("MOD", &args, row_index)
282            }
283            ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
286                let result = compare_with_op(&left_val, &right_val, op, false);
287                Ok(DataValue::Boolean(result))
288            }
289            "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
291            "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
292            "AND" => {
294                let left_bool = self.to_bool(&left_val)?;
295                let right_bool = self.to_bool(&right_val)?;
296                Ok(DataValue::Boolean(left_bool && right_bool))
297            }
298            "OR" => {
299                let left_bool = self.to_bool(&left_val)?;
300                let right_bool = self.to_bool(&right_val)?;
301                Ok(DataValue::Boolean(left_bool || right_bool))
302            }
303            _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
304        }
305    }
306
307    fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
309        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
311            return Ok(DataValue::Null);
312        }
313
314        match (left, right) {
315            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
316            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
317            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
318            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
319            _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
320        }
321    }
322
323    fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
325        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
327            return Ok(DataValue::Null);
328        }
329
330        match (left, right) {
331            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
332            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
333            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
334            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
335            _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
336        }
337    }
338
339    fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
341        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
343            return Ok(DataValue::Null);
344        }
345
346        match (left, right) {
347            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
348            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
349            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
350            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
351            _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
352        }
353    }
354
355    fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
357        if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
359            return Ok(DataValue::Null);
360        }
361
362        let is_zero = match right {
364            DataValue::Integer(0) => true,
365            DataValue::Float(f) if *f == 0.0 => true, _ => false,
367        };
368
369        if is_zero {
370            return Err(anyhow!("Division by zero"));
371        }
372
373        match (left, right) {
374            (DataValue::Integer(a), DataValue::Integer(b)) => {
375                if a % b == 0 {
377                    Ok(DataValue::Integer(a / b))
378                } else {
379                    Ok(DataValue::Float(*a as f64 / *b as f64))
380                }
381            }
382            (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
383            (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
384            (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
385            _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
386        }
387    }
388
389    fn format_value(&self, value: &DataValue) -> String {
391        match value {
392            DataValue::Integer(i) => i.to_string(),
393            DataValue::Float(f) => f.to_string(),
394            DataValue::String(s) => format!("'{s}'"),
395            _ => format!("{value:?}"),
396        }
397    }
398
399    fn to_bool(&self, value: &DataValue) -> Result<bool> {
401        match value {
402            DataValue::Boolean(b) => Ok(*b),
403            DataValue::Integer(i) => Ok(*i != 0),
404            DataValue::Float(f) => Ok(*f != 0.0),
405            DataValue::Null => Ok(false),
406            _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
407        }
408    }
409
410    fn evaluate_function_with_distinct(
412        &mut self,
413        name: &str,
414        args: &[SqlExpression],
415        distinct: bool,
416        row_index: usize,
417    ) -> Result<DataValue> {
418        if distinct {
420            let name_upper = name.to_uppercase();
421
422            if self.aggregate_registry.is_aggregate(&name_upper) {
424                return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
425            } else {
426                return Err(anyhow!(
427                    "DISTINCT can only be used with aggregate functions"
428                ));
429            }
430        }
431
432        self.evaluate_function(name, args, row_index)
434    }
435
436    fn evaluate_aggregate_with_distinct(
437        &mut self,
438        name: &str,
439        args: &[SqlExpression],
440        _row_index: usize,
441    ) -> Result<DataValue> {
442        let name_upper = name.to_uppercase();
443
444        if self.aggregate_registry.get(&name_upper).is_some() {
446            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
448                visible.clone()
449            } else {
450                (0..self.table.rows.len()).collect()
451            };
452
453            if name_upper == "STRING_AGG" && args.len() >= 2 {
455                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
457                    if args.len() >= 2 {
459                        let separator = self.evaluate(&args[1], 0)?; match separator {
461                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
462                            DataValue::InternedString(s) => {
463                                crate::sql::aggregates::StringAggState::new(&s)
464                            }
465                            _ => crate::sql::aggregates::StringAggState::new(","), }
467                    } else {
468                        crate::sql::aggregates::StringAggState::new(",")
469                    },
470                );
471
472                let mut seen_values = HashSet::new();
475
476                for &row_idx in &rows_to_process {
477                    let value = self.evaluate(&args[0], row_idx)?;
478
479                    if !seen_values.insert(value.clone()) {
481                        continue; }
483
484                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
486                    agg_func.accumulate(&mut state, &value)?;
487                }
488
489                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
491                return Ok(agg_func.finalize(state));
492            }
493
494            let mut vals = Vec::new();
497            for &row_idx in &rows_to_process {
498                if !args.is_empty() {
499                    let value = self.evaluate(&args[0], row_idx)?;
500                    vals.push(value);
501                }
502            }
503
504            let mut seen = HashSet::new();
506            let mut unique_values = Vec::new();
507            for value in vals {
508                if seen.insert(value.clone()) {
509                    unique_values.push(value);
510                }
511            }
512
513            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
515            let mut state = agg_func.init();
516
517            for value in &unique_values {
519                agg_func.accumulate(&mut state, value)?;
520            }
521
522            return Ok(agg_func.finalize(state));
523        }
524
525        Err(anyhow!("Unknown aggregate function: {}", name))
526    }
527
528    fn evaluate_function(
529        &mut self,
530        name: &str,
531        args: &[SqlExpression],
532        row_index: usize,
533    ) -> Result<DataValue> {
534        let name_upper = name.to_uppercase();
536
537        if name_upper == "COUNT" && args.len() == 1 {
539            match &args[0] {
540                SqlExpression::Column(col) if col == "*" => {
541                    let count = if let Some(ref visible) = self.visible_rows {
543                        visible.len() as i64
544                    } else {
545                        self.table.rows.len() as i64
546                    };
547                    return Ok(DataValue::Integer(count));
548                }
549                SqlExpression::StringLiteral(s) if s == "*" => {
550                    let count = if let Some(ref visible) = self.visible_rows {
552                        visible.len() as i64
553                    } else {
554                        self.table.rows.len() as i64
555                    };
556                    return Ok(DataValue::Integer(count));
557                }
558                _ => {
559                    }
561            }
562        }
563
564        if self.aggregate_registry.get(&name_upper).is_some() {
566            let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
568                visible.clone()
569            } else {
570                (0..self.table.rows.len()).collect()
571            };
572
573            if name_upper == "STRING_AGG" && args.len() >= 2 {
575                let mut state = crate::sql::aggregates::AggregateState::StringAgg(
577                    if args.len() >= 2 {
579                        let separator = self.evaluate(&args[1], 0)?; match separator {
581                            DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
582                            DataValue::InternedString(s) => {
583                                crate::sql::aggregates::StringAggState::new(&s)
584                            }
585                            _ => crate::sql::aggregates::StringAggState::new(","), }
587                    } else {
588                        crate::sql::aggregates::StringAggState::new(",")
589                    },
590                );
591
592                for &row_idx in &rows_to_process {
594                    let value = self.evaluate(&args[0], row_idx)?;
595                    let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
597                    agg_func.accumulate(&mut state, &value)?;
598                }
599
600                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
602                return Ok(agg_func.finalize(state));
603            }
604
605            let values = if !args.is_empty()
607                && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
608            {
609                let mut vals = Vec::new();
611                for &row_idx in &rows_to_process {
612                    let value = self.evaluate(&args[0], row_idx)?;
613                    vals.push(value);
614                }
615                Some(vals)
616            } else {
617                None
618            };
619
620            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
622            let mut state = agg_func.init();
623
624            if let Some(values) = values {
625                for value in &values {
627                    agg_func.accumulate(&mut state, value)?;
628                }
629            } else {
630                for _ in &rows_to_process {
632                    agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
633                }
634            }
635
636            return Ok(agg_func.finalize(state));
637        }
638
639        if self.function_registry.get(name).is_some() {
641            let mut evaluated_args = Vec::new();
643            for arg in args {
644                evaluated_args.push(self.evaluate(arg, row_index)?);
645            }
646
647            let func = self.function_registry.get(name).unwrap();
649            return func.evaluate(&evaluated_args);
650        }
651
652        Err(anyhow!("Unknown function: {}", name))
654    }
655
656    fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
658        let key = format!("{:?}", spec);
660
661        if let Some(context) = self.window_contexts.get(&key) {
662            return Ok(Arc::clone(context));
663        }
664
665        let data_view = if let Some(ref _visible_rows) = self.visible_rows {
667            let view = DataView::new(Arc::new(self.table.clone()));
669            view
672        } else {
673            DataView::new(Arc::new(self.table.clone()))
674        };
675
676        let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
678
679        let context = Arc::new(context);
680        self.window_contexts.insert(key, Arc::clone(&context));
681        Ok(context)
682    }
683
684    fn evaluate_window_function(
686        &mut self,
687        name: &str,
688        args: &[SqlExpression],
689        spec: &WindowSpec,
690        row_index: usize,
691    ) -> Result<DataValue> {
692        let name_upper = name.to_uppercase();
693
694        debug!("Looking for window function {} in registry", name_upper);
696        if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
697            debug!("Found window function {} in registry", name_upper);
698
699            let window_fn = window_fn_arc.as_ref();
701
702            window_fn.validate_args(args)?;
704
705            let transformed_spec = window_fn.transform_window_spec(spec, args)?;
707
708            let context = self.get_or_create_window_context(&transformed_spec)?;
710
711            struct EvaluatorAdapter<'a, 'b> {
713                evaluator: &'a mut ArithmeticEvaluator<'b>,
714                row_index: usize,
715            }
716
717            impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
718                fn evaluate(
719                    &mut self,
720                    expr: &SqlExpression,
721                    _row_index: usize,
722                ) -> Result<DataValue> {
723                    self.evaluator.evaluate(expr, self.row_index)
724                }
725            }
726
727            let mut adapter = EvaluatorAdapter {
728                evaluator: self,
729                row_index,
730            };
731
732            return window_fn.compute(&context, row_index, args, &mut adapter);
734        }
735
736        let context = self.get_or_create_window_context(spec)?;
738
739        match name_upper.as_str() {
740            "LAG" => {
741                if args.is_empty() {
743                    return Err(anyhow!("LAG requires at least 1 argument"));
744                }
745
746                let column = match &args[0] {
748                    SqlExpression::Column(col) => col.clone(),
749                    _ => return Err(anyhow!("LAG first argument must be a column")),
750                };
751
752                let offset = if args.len() > 1 {
754                    match self.evaluate(&args[1], row_index)? {
755                        DataValue::Integer(i) => i as i32,
756                        _ => return Err(anyhow!("LAG offset must be an integer")),
757                    }
758                } else {
759                    1
760                };
761
762                Ok(context
764                    .get_offset_value(row_index, -offset, &column)
765                    .unwrap_or(DataValue::Null))
766            }
767            "LEAD" => {
768                if args.is_empty() {
770                    return Err(anyhow!("LEAD requires at least 1 argument"));
771                }
772
773                let column = match &args[0] {
775                    SqlExpression::Column(col) => col.clone(),
776                    _ => return Err(anyhow!("LEAD first argument must be a column")),
777                };
778
779                let offset = if args.len() > 1 {
781                    match self.evaluate(&args[1], row_index)? {
782                        DataValue::Integer(i) => i as i32,
783                        _ => return Err(anyhow!("LEAD offset must be an integer")),
784                    }
785                } else {
786                    1
787                };
788
789                Ok(context
791                    .get_offset_value(row_index, offset, &column)
792                    .unwrap_or(DataValue::Null))
793            }
794            "ROW_NUMBER" => {
795                Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
797            }
798            "FIRST_VALUE" => {
799                if args.is_empty() {
801                    return Err(anyhow!("FIRST_VALUE requires 1 argument"));
802                }
803
804                let column = match &args[0] {
805                    SqlExpression::Column(col) => col.clone(),
806                    _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
807                };
808
809                if context.has_frame() {
811                    Ok(context
812                        .get_frame_first_value(row_index, &column)
813                        .unwrap_or(DataValue::Null))
814                } else {
815                    Ok(context
816                        .get_first_value(row_index, &column)
817                        .unwrap_or(DataValue::Null))
818                }
819            }
820            "LAST_VALUE" => {
821                if args.is_empty() {
823                    return Err(anyhow!("LAST_VALUE requires 1 argument"));
824                }
825
826                let column = match &args[0] {
827                    SqlExpression::Column(col) => col.clone(),
828                    _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
829                };
830
831                if context.has_frame() {
833                    Ok(context
834                        .get_frame_last_value(row_index, &column)
835                        .unwrap_or(DataValue::Null))
836                } else {
837                    Ok(context
838                        .get_last_value(row_index, &column)
839                        .unwrap_or(DataValue::Null))
840                }
841            }
842            "SUM" => {
843                if args.is_empty() {
845                    return Err(anyhow!("SUM requires 1 argument"));
846                }
847
848                let column = match &args[0] {
849                    SqlExpression::Column(col) => col.clone(),
850                    _ => return Err(anyhow!("SUM argument must be a column")),
851                };
852
853                if context.has_frame() {
855                    Ok(context
856                        .get_frame_sum(row_index, &column)
857                        .unwrap_or(DataValue::Null))
858                } else {
859                    Ok(context
860                        .get_partition_sum(row_index, &column)
861                        .unwrap_or(DataValue::Null))
862                }
863            }
864            "AVG" => {
865                if args.is_empty() {
867                    return Err(anyhow!("AVG requires 1 argument"));
868                }
869
870                let column = match &args[0] {
871                    SqlExpression::Column(col) => col.clone(),
872                    _ => return Err(anyhow!("AVG argument must be a column")),
873                };
874
875                Ok(context
876                    .get_frame_avg(row_index, &column)
877                    .unwrap_or(DataValue::Null))
878            }
879            "STDDEV" | "STDEV" => {
880                if args.is_empty() {
882                    return Err(anyhow!("STDDEV requires 1 argument"));
883                }
884
885                let column = match &args[0] {
886                    SqlExpression::Column(col) => col.clone(),
887                    _ => return Err(anyhow!("STDDEV argument must be a column")),
888                };
889
890                Ok(context
891                    .get_frame_stddev(row_index, &column)
892                    .unwrap_or(DataValue::Null))
893            }
894            "VARIANCE" | "VAR" => {
895                if args.is_empty() {
897                    return Err(anyhow!("VARIANCE requires 1 argument"));
898                }
899
900                let column = match &args[0] {
901                    SqlExpression::Column(col) => col.clone(),
902                    _ => return Err(anyhow!("VARIANCE argument must be a column")),
903                };
904
905                Ok(context
906                    .get_frame_variance(row_index, &column)
907                    .unwrap_or(DataValue::Null))
908            }
909            "MIN" => {
910                if args.is_empty() {
912                    return Err(anyhow!("MIN requires 1 argument"));
913                }
914
915                let column = match &args[0] {
916                    SqlExpression::Column(col) => col.clone(),
917                    _ => return Err(anyhow!("MIN argument must be a column")),
918                };
919
920                let frame_rows = context.get_frame_rows(row_index);
921                if frame_rows.is_empty() {
922                    return Ok(DataValue::Null);
923                }
924
925                let source_table = context.source();
926                let col_idx = source_table
927                    .get_column_index(&column)
928                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
929
930                let mut min_value: Option<DataValue> = None;
931                for &row_idx in &frame_rows {
932                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
933                        if !matches!(value, DataValue::Null) {
934                            match &min_value {
935                                None => min_value = Some(value.clone()),
936                                Some(current_min) => {
937                                    if value < current_min {
938                                        min_value = Some(value.clone());
939                                    }
940                                }
941                            }
942                        }
943                    }
944                }
945
946                Ok(min_value.unwrap_or(DataValue::Null))
947            }
948            "MAX" => {
949                if args.is_empty() {
951                    return Err(anyhow!("MAX requires 1 argument"));
952                }
953
954                let column = match &args[0] {
955                    SqlExpression::Column(col) => col.clone(),
956                    _ => return Err(anyhow!("MAX argument must be a column")),
957                };
958
959                let frame_rows = context.get_frame_rows(row_index);
960                if frame_rows.is_empty() {
961                    return Ok(DataValue::Null);
962                }
963
964                let source_table = context.source();
965                let col_idx = source_table
966                    .get_column_index(&column)
967                    .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
968
969                let mut max_value: Option<DataValue> = None;
970                for &row_idx in &frame_rows {
971                    if let Some(value) = source_table.get_value(row_idx, col_idx) {
972                        if !matches!(value, DataValue::Null) {
973                            match &max_value {
974                                None => max_value = Some(value.clone()),
975                                Some(current_max) => {
976                                    if value > current_max {
977                                        max_value = Some(value.clone());
978                                    }
979                                }
980                            }
981                        }
982                    }
983                }
984
985                Ok(max_value.unwrap_or(DataValue::Null))
986            }
987            "COUNT" => {
988                if args.is_empty() {
992                    if context.has_frame() {
994                        Ok(context
995                            .get_frame_count(row_index, None)
996                            .unwrap_or(DataValue::Null))
997                    } else {
998                        Ok(context
999                            .get_partition_count(row_index, None)
1000                            .unwrap_or(DataValue::Null))
1001                    }
1002                } else {
1003                    let column = match &args[0] {
1005                        SqlExpression::Column(col) => {
1006                            if col == "*" {
1007                                if context.has_frame() {
1009                                    return Ok(context
1010                                        .get_frame_count(row_index, None)
1011                                        .unwrap_or(DataValue::Null));
1012                                } else {
1013                                    return Ok(context
1014                                        .get_partition_count(row_index, None)
1015                                        .unwrap_or(DataValue::Null));
1016                                }
1017                            }
1018                            col.clone()
1019                        }
1020                        SqlExpression::StringLiteral(s) if s == "*" => {
1021                            if context.has_frame() {
1023                                return Ok(context
1024                                    .get_frame_count(row_index, None)
1025                                    .unwrap_or(DataValue::Null));
1026                            } else {
1027                                return Ok(context
1028                                    .get_partition_count(row_index, None)
1029                                    .unwrap_or(DataValue::Null));
1030                            }
1031                        }
1032                        _ => return Err(anyhow!("COUNT argument must be a column or *")),
1033                    };
1034
1035                    if context.has_frame() {
1037                        Ok(context
1038                            .get_frame_count(row_index, Some(&column))
1039                            .unwrap_or(DataValue::Null))
1040                    } else {
1041                        Ok(context
1042                            .get_partition_count(row_index, Some(&column))
1043                            .unwrap_or(DataValue::Null))
1044                    }
1045                }
1046            }
1047            _ => Err(anyhow!("Unknown window function: {}", name)),
1048        }
1049    }
1050
1051    fn evaluate_method_call(
1053        &mut self,
1054        object: &str,
1055        method: &str,
1056        args: &[SqlExpression],
1057        row_index: usize,
1058    ) -> Result<DataValue> {
1059        let col_index = self.table.get_column_index(object).ok_or_else(|| {
1061            let suggestion = self.find_similar_column(object);
1062            match suggestion {
1063                Some(similar) => {
1064                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1065                }
1066                None => anyhow!("Column '{}' not found", object),
1067            }
1068        })?;
1069
1070        let cell_value = self.table.get_value(row_index, col_index).cloned();
1071
1072        self.evaluate_method_on_value(
1073            &cell_value.unwrap_or(DataValue::Null),
1074            method,
1075            args,
1076            row_index,
1077        )
1078    }
1079
1080    fn evaluate_method_on_value(
1082        &mut self,
1083        value: &DataValue,
1084        method: &str,
1085        args: &[SqlExpression],
1086        row_index: usize,
1087    ) -> Result<DataValue> {
1088        let function_name = match method.to_lowercase().as_str() {
1093            "trim" => "TRIM",
1094            "trimstart" | "trimbegin" => "TRIMSTART",
1095            "trimend" => "TRIMEND",
1096            "length" | "len" => "LENGTH",
1097            "contains" => "CONTAINS",
1098            "startswith" => "STARTSWITH",
1099            "endswith" => "ENDSWITH",
1100            "indexof" => "INDEXOF",
1101            _ => method, };
1103
1104        if self.function_registry.get(function_name).is_some() {
1106            debug!(
1107                "Proxying method '{}' through function registry as '{}'",
1108                method, function_name
1109            );
1110
1111            let mut func_args = vec![value.clone()];
1113
1114            for arg in args {
1116                func_args.push(self.evaluate(arg, row_index)?);
1117            }
1118
1119            let func = self.function_registry.get(function_name).unwrap();
1121            return func.evaluate(&func_args);
1122        }
1123
1124        Err(anyhow!(
1127            "Method '{}' not found. It should be registered in the function registry.",
1128            method
1129        ))
1130    }
1131
1132    fn evaluate_case_expression(
1134        &mut self,
1135        when_branches: &[crate::sql::recursive_parser::WhenBranch],
1136        else_branch: &Option<Box<SqlExpression>>,
1137        row_index: usize,
1138    ) -> Result<DataValue> {
1139        debug!(
1140            "ArithmeticEvaluator: evaluating CASE expression for row {}",
1141            row_index
1142        );
1143
1144        for branch in when_branches {
1146            let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1148
1149            if condition_result {
1150                debug!("CASE: WHEN condition matched, evaluating result expression");
1151                return self.evaluate(&branch.result, row_index);
1152            }
1153        }
1154
1155        if let Some(else_expr) = else_branch {
1157            debug!("CASE: No WHEN matched, evaluating ELSE expression");
1158            self.evaluate(else_expr, row_index)
1159        } else {
1160            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1161            Ok(DataValue::Null)
1162        }
1163    }
1164
1165    fn evaluate_simple_case_expression(
1167        &mut self,
1168        expr: &Box<SqlExpression>,
1169        when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1170        else_branch: &Option<Box<SqlExpression>>,
1171        row_index: usize,
1172    ) -> Result<DataValue> {
1173        debug!(
1174            "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1175            row_index
1176        );
1177
1178        let case_value = self.evaluate(expr, row_index)?;
1180        debug!("Simple CASE: evaluated expression to {:?}", case_value);
1181
1182        for branch in when_branches {
1184            let when_value = self.evaluate(&branch.value, row_index)?;
1186
1187            if self.values_equal(&case_value, &when_value)? {
1189                debug!("Simple CASE: WHEN value matched, evaluating result expression");
1190                return self.evaluate(&branch.result, row_index);
1191            }
1192        }
1193
1194        if let Some(else_expr) = else_branch {
1196            debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1197            self.evaluate(else_expr, row_index)
1198        } else {
1199            debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1200            Ok(DataValue::Null)
1201        }
1202    }
1203
1204    fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1206        match (left, right) {
1207            (DataValue::Null, DataValue::Null) => Ok(true),
1208            (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1209            (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1210            (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1211            (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1212            (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1213            (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1214            (DataValue::Integer(a), DataValue::Float(b)) => {
1216                Ok((*a as f64 - b).abs() < f64::EPSILON)
1217            }
1218            (DataValue::Float(a), DataValue::Integer(b)) => {
1219                Ok((a - *b as f64).abs() < f64::EPSILON)
1220            }
1221            _ => Ok(false),
1222        }
1223    }
1224
1225    fn evaluate_condition_as_bool(
1227        &mut self,
1228        expr: &SqlExpression,
1229        row_index: usize,
1230    ) -> Result<bool> {
1231        let value = self.evaluate(expr, row_index)?;
1232
1233        match value {
1234            DataValue::Boolean(b) => Ok(b),
1235            DataValue::Integer(i) => Ok(i != 0),
1236            DataValue::Float(f) => Ok(f != 0.0),
1237            DataValue::Null => Ok(false),
1238            DataValue::String(s) => Ok(!s.is_empty()),
1239            DataValue::InternedString(s) => Ok(!s.is_empty()),
1240            _ => Ok(true), }
1242    }
1243}
1244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    /// A column reference yields the cell stored in that column.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column("a".to_string());
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    /// Number literals become integers or floats depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal, expected) in cases {
            let expr = SqlExpression::NumberLiteral(literal.to_string());
            let result = evaluator.evaluate(&expr, 0).unwrap();
            assert_eq!(result, expected);
        }
    }

    /// Integer + integer stays an integer; integer + float becomes a float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let five = DataValue::Integer(5);

        let int_sum = evaluator
            .add_values(&five, &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        let mixed_sum = evaluator
            .add_values(&five, &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Integer * float produces a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Exact integer division stays an integer; inexact division is a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let ten = DataValue::Integer(10);

        let exact = evaluator
            .divide_values(&ten, &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        let inexact = evaluator
            .divide_values(&ten, &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// `a * b` over the test row multiplies 10 by 2.5.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let lhs = Box::new(SqlExpression::Column("a".to_string()));
        let rhs = Box::new(SqlExpression::Column("b".to_string()));
        let product_expr = SqlExpression::BinaryOp {
            left: lhs,
            op: "*".to_string(),
            right: rhs,
        };

        assert_eq!(
            evaluator.evaluate(&product_expr, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}