1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::parser::ast::WindowSpec;
8use crate::sql::recursive_parser::SqlExpression;
9use crate::sql::window_context::WindowContext;
10use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tracing::debug;
15
16pub struct ArithmeticEvaluator<'a> {
19 table: &'a DataTable,
20 date_notation: String,
21 function_registry: Arc<FunctionRegistry>,
22 aggregate_registry: Arc<AggregateRegistry>,
23 window_function_registry: Arc<WindowFunctionRegistry>,
24 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
28
29impl<'a> ArithmeticEvaluator<'a> {
30 #[must_use]
31 pub fn new(table: &'a DataTable) -> Self {
32 Self {
33 table,
34 date_notation: get_date_notation(),
35 function_registry: Arc::new(FunctionRegistry::new()),
36 aggregate_registry: Arc::new(AggregateRegistry::new()),
37 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
38 visible_rows: None,
39 window_contexts: HashMap::new(),
40 table_aliases: HashMap::new(),
41 }
42 }
43
44 #[must_use]
45 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
46 Self {
47 table,
48 date_notation,
49 function_registry: Arc::new(FunctionRegistry::new()),
50 aggregate_registry: Arc::new(AggregateRegistry::new()),
51 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
52 visible_rows: None,
53 window_contexts: HashMap::new(),
54 table_aliases: HashMap::new(),
55 }
56 }
57
58 #[must_use]
60 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
61 self.visible_rows = Some(rows);
62 self
63 }
64
65 #[must_use]
67 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
68 self.table_aliases = aliases;
69 self
70 }
71
72 #[must_use]
73 pub fn with_date_notation_and_registry(
74 table: &'a DataTable,
75 date_notation: String,
76 function_registry: Arc<FunctionRegistry>,
77 ) -> Self {
78 Self {
79 table,
80 date_notation,
81 function_registry,
82 aggregate_registry: Arc::new(AggregateRegistry::new()),
83 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
84 visible_rows: None,
85 window_contexts: HashMap::new(),
86 table_aliases: HashMap::new(),
87 }
88 }
89
90 fn find_similar_column(&self, name: &str) -> Option<String> {
92 let columns = self.table.column_names();
93 let mut best_match: Option<(String, usize)> = None;
94
95 for col in columns {
96 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
97 let max_distance = if name.len() > 10 { 3 } else { 2 };
100 if distance <= max_distance {
101 match &best_match {
102 None => best_match = Some((col, distance)),
103 Some((_, best_dist)) if distance < *best_dist => {
104 best_match = Some((col, distance));
105 }
106 _ => {}
107 }
108 }
109 }
110
111 best_match.map(|(name, _)| name)
112 }
113
114 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
116 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
118 }
119
120 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
122 debug!(
123 "ArithmeticEvaluator: evaluating {:?} for row {}",
124 expr, row_index
125 );
126
127 match expr {
128 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
129 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
130 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
131 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
132 SqlExpression::Null => Ok(DataValue::Null),
133 SqlExpression::BinaryOp { left, op, right } => {
134 self.evaluate_binary_op(left, op, right, row_index)
135 }
136 SqlExpression::FunctionCall {
137 name,
138 args,
139 distinct,
140 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
141 SqlExpression::WindowFunction {
142 name,
143 args,
144 window_spec,
145 } => self.evaluate_window_function(name, args, window_spec, row_index),
146 SqlExpression::MethodCall {
147 object,
148 method,
149 args,
150 } => self.evaluate_method_call(object, method, args, row_index),
151 SqlExpression::ChainedMethodCall { base, method, args } => {
152 let base_value = self.evaluate(base, row_index)?;
154 self.evaluate_method_on_value(&base_value, method, args, row_index)
155 }
156 SqlExpression::CaseExpression {
157 when_branches,
158 else_branch,
159 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
160 _ => Err(anyhow!(
161 "Unsupported expression type for arithmetic evaluation: {:?}",
162 expr
163 )),
164 }
165 }
166
167 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
169 let resolved_column = if column_name.contains('.') {
171 if let Some(dot_pos) = column_name.rfind('.') {
173 let _table_or_alias = &column_name[..dot_pos];
174 let col_name = &column_name[dot_pos + 1..];
175
176 debug!(
179 "Resolving qualified column: {} -> {}",
180 column_name, col_name
181 );
182 col_name.to_string()
183 } else {
184 column_name.to_string()
185 }
186 } else {
187 column_name.to_string()
188 };
189
190 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
191 idx
192 } else if resolved_column != column_name {
193 if let Some(idx) = self.table.get_column_index(column_name) {
195 idx
196 } else {
197 let suggestion = self.find_similar_column(&resolved_column);
198 return Err(match suggestion {
199 Some(similar) => anyhow!(
200 "Column '{}' not found. Did you mean '{}'?",
201 column_name,
202 similar
203 ),
204 None => anyhow!("Column '{}' not found", column_name),
205 });
206 }
207 } else {
208 let suggestion = self.find_similar_column(&resolved_column);
209 return Err(match suggestion {
210 Some(similar) => anyhow!(
211 "Column '{}' not found. Did you mean '{}'?",
212 column_name,
213 similar
214 ),
215 None => anyhow!("Column '{}' not found", column_name),
216 });
217 };
218
219 if row_index >= self.table.row_count() {
220 return Err(anyhow!("Row index {} out of bounds", row_index));
221 }
222
223 let row = self
224 .table
225 .get_row(row_index)
226 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
227
228 let value = row
229 .get(col_index)
230 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
231
232 Ok(value.clone())
233 }
234
235 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
237 if let Ok(int_val) = number_str.parse::<i64>() {
239 return Ok(DataValue::Integer(int_val));
240 }
241
242 if let Ok(float_val) = number_str.parse::<f64>() {
244 return Ok(DataValue::Float(float_val));
245 }
246
247 Err(anyhow!("Invalid number literal: {}", number_str))
248 }
249
250 fn evaluate_binary_op(
252 &mut self,
253 left: &SqlExpression,
254 op: &str,
255 right: &SqlExpression,
256 row_index: usize,
257 ) -> Result<DataValue> {
258 let left_val = self.evaluate(left, row_index)?;
259 let right_val = self.evaluate(right, row_index)?;
260
261 debug!(
262 "ArithmeticEvaluator: {} {} {}",
263 self.format_value(&left_val),
264 op,
265 self.format_value(&right_val)
266 );
267
268 match op {
269 "+" => self.add_values(&left_val, &right_val),
270 "-" => self.subtract_values(&left_val, &right_val),
271 "*" => self.multiply_values(&left_val, &right_val),
272 "/" => self.divide_values(&left_val, &right_val),
273 "%" => {
274 let args = vec![left.clone(), right.clone()];
276 self.evaluate_function("MOD", &args, row_index)
277 }
278 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
281 let result = compare_with_op(&left_val, &right_val, op, false);
282 Ok(DataValue::Boolean(result))
283 }
284 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
286 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
287 "AND" => {
289 let left_bool = self.to_bool(&left_val)?;
290 let right_bool = self.to_bool(&right_val)?;
291 Ok(DataValue::Boolean(left_bool && right_bool))
292 }
293 "OR" => {
294 let left_bool = self.to_bool(&left_val)?;
295 let right_bool = self.to_bool(&right_val)?;
296 Ok(DataValue::Boolean(left_bool || right_bool))
297 }
298 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
299 }
300 }
301
302 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
304 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
306 return Ok(DataValue::Null);
307 }
308
309 match (left, right) {
310 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
311 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
312 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
313 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
314 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
315 }
316 }
317
318 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
320 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
322 return Ok(DataValue::Null);
323 }
324
325 match (left, right) {
326 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
327 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
328 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
329 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
330 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
331 }
332 }
333
334 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
336 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
338 return Ok(DataValue::Null);
339 }
340
341 match (left, right) {
342 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
343 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
344 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
345 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
346 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
347 }
348 }
349
350 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
352 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
354 return Ok(DataValue::Null);
355 }
356
357 let is_zero = match right {
359 DataValue::Integer(0) => true,
360 DataValue::Float(f) if *f == 0.0 => true, _ => false,
362 };
363
364 if is_zero {
365 return Err(anyhow!("Division by zero"));
366 }
367
368 match (left, right) {
369 (DataValue::Integer(a), DataValue::Integer(b)) => {
370 if a % b == 0 {
372 Ok(DataValue::Integer(a / b))
373 } else {
374 Ok(DataValue::Float(*a as f64 / *b as f64))
375 }
376 }
377 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
378 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
379 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
380 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
381 }
382 }
383
384 fn format_value(&self, value: &DataValue) -> String {
386 match value {
387 DataValue::Integer(i) => i.to_string(),
388 DataValue::Float(f) => f.to_string(),
389 DataValue::String(s) => format!("'{s}'"),
390 _ => format!("{value:?}"),
391 }
392 }
393
394 fn to_bool(&self, value: &DataValue) -> Result<bool> {
396 match value {
397 DataValue::Boolean(b) => Ok(*b),
398 DataValue::Integer(i) => Ok(*i != 0),
399 DataValue::Float(f) => Ok(*f != 0.0),
400 DataValue::Null => Ok(false),
401 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
402 }
403 }
404
405 fn evaluate_function_with_distinct(
407 &mut self,
408 name: &str,
409 args: &[SqlExpression],
410 distinct: bool,
411 row_index: usize,
412 ) -> Result<DataValue> {
413 if distinct {
415 let name_upper = name.to_uppercase();
416
417 if self.aggregate_registry.is_aggregate(&name_upper) {
419 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
420 } else {
421 return Err(anyhow!(
422 "DISTINCT can only be used with aggregate functions"
423 ));
424 }
425 }
426
427 self.evaluate_function(name, args, row_index)
429 }
430
431 fn evaluate_aggregate_with_distinct(
432 &mut self,
433 name: &str,
434 args: &[SqlExpression],
435 _row_index: usize,
436 ) -> Result<DataValue> {
437 let name_upper = name.to_uppercase();
438
439 if self.aggregate_registry.get(&name_upper).is_some() {
441 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
443 visible.clone()
444 } else {
445 (0..self.table.rows.len()).collect()
446 };
447
448 if name_upper == "STRING_AGG" && args.len() >= 2 {
450 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
452 if args.len() >= 2 {
454 let separator = self.evaluate(&args[1], 0)?; match separator {
456 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
457 DataValue::InternedString(s) => {
458 crate::sql::aggregates::StringAggState::new(&s)
459 }
460 _ => crate::sql::aggregates::StringAggState::new(","), }
462 } else {
463 crate::sql::aggregates::StringAggState::new(",")
464 },
465 );
466
467 let mut seen_values = HashSet::new();
470
471 for &row_idx in &rows_to_process {
472 let value = self.evaluate(&args[0], row_idx)?;
473
474 if !seen_values.insert(value.clone()) {
476 continue; }
478
479 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
481 agg_func.accumulate(&mut state, &value)?;
482 }
483
484 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
486 return Ok(agg_func.finalize(state));
487 }
488
489 let mut vals = Vec::new();
492 for &row_idx in &rows_to_process {
493 if !args.is_empty() {
494 let value = self.evaluate(&args[0], row_idx)?;
495 vals.push(value);
496 }
497 }
498
499 let mut seen = HashSet::new();
501 let mut unique_values = Vec::new();
502 for value in vals {
503 if seen.insert(value.clone()) {
504 unique_values.push(value);
505 }
506 }
507
508 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
510 let mut state = agg_func.init();
511
512 for value in &unique_values {
514 agg_func.accumulate(&mut state, value)?;
515 }
516
517 return Ok(agg_func.finalize(state));
518 }
519
520 Err(anyhow!("Unknown aggregate function: {}", name))
521 }
522
523 fn evaluate_function(
524 &mut self,
525 name: &str,
526 args: &[SqlExpression],
527 row_index: usize,
528 ) -> Result<DataValue> {
529 let name_upper = name.to_uppercase();
531
532 if name_upper == "COUNT" && args.len() == 1 {
534 match &args[0] {
535 SqlExpression::Column(col) if col == "*" => {
536 let count = if let Some(ref visible) = self.visible_rows {
538 visible.len() as i64
539 } else {
540 self.table.rows.len() as i64
541 };
542 return Ok(DataValue::Integer(count));
543 }
544 SqlExpression::StringLiteral(s) if s == "*" => {
545 let count = if let Some(ref visible) = self.visible_rows {
547 visible.len() as i64
548 } else {
549 self.table.rows.len() as i64
550 };
551 return Ok(DataValue::Integer(count));
552 }
553 _ => {
554 }
556 }
557 }
558
559 if self.aggregate_registry.get(&name_upper).is_some() {
561 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
563 visible.clone()
564 } else {
565 (0..self.table.rows.len()).collect()
566 };
567
568 if name_upper == "STRING_AGG" && args.len() >= 2 {
570 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
572 if args.len() >= 2 {
574 let separator = self.evaluate(&args[1], 0)?; match separator {
576 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
577 DataValue::InternedString(s) => {
578 crate::sql::aggregates::StringAggState::new(&s)
579 }
580 _ => crate::sql::aggregates::StringAggState::new(","), }
582 } else {
583 crate::sql::aggregates::StringAggState::new(",")
584 },
585 );
586
587 for &row_idx in &rows_to_process {
589 let value = self.evaluate(&args[0], row_idx)?;
590 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
592 agg_func.accumulate(&mut state, &value)?;
593 }
594
595 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
597 return Ok(agg_func.finalize(state));
598 }
599
600 let values = if !args.is_empty()
602 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
603 {
604 let mut vals = Vec::new();
606 for &row_idx in &rows_to_process {
607 let value = self.evaluate(&args[0], row_idx)?;
608 vals.push(value);
609 }
610 Some(vals)
611 } else {
612 None
613 };
614
615 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
617 let mut state = agg_func.init();
618
619 if let Some(values) = values {
620 for value in &values {
622 agg_func.accumulate(&mut state, value)?;
623 }
624 } else {
625 for _ in &rows_to_process {
627 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
628 }
629 }
630
631 return Ok(agg_func.finalize(state));
632 }
633
634 if self.function_registry.get(name).is_some() {
636 let mut evaluated_args = Vec::new();
638 for arg in args {
639 evaluated_args.push(self.evaluate(arg, row_index)?);
640 }
641
642 let func = self.function_registry.get(name).unwrap();
644 return func.evaluate(&evaluated_args);
645 }
646
647 Err(anyhow!("Unknown function: {}", name))
649 }
650
651 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
653 let key = format!("{:?}", spec);
655
656 if let Some(context) = self.window_contexts.get(&key) {
657 return Ok(Arc::clone(context));
658 }
659
660 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
662 let view = DataView::new(Arc::new(self.table.clone()));
664 view
667 } else {
668 DataView::new(Arc::new(self.table.clone()))
669 };
670
671 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
673
674 let context = Arc::new(context);
675 self.window_contexts.insert(key, Arc::clone(&context));
676 Ok(context)
677 }
678
679 fn evaluate_window_function(
681 &mut self,
682 name: &str,
683 args: &[SqlExpression],
684 spec: &WindowSpec,
685 row_index: usize,
686 ) -> Result<DataValue> {
687 let name_upper = name.to_uppercase();
688
689 debug!("Looking for window function {} in registry", name_upper);
691 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
692 debug!("Found window function {} in registry", name_upper);
693
694 let window_fn = window_fn_arc.as_ref();
696
697 window_fn.validate_args(args)?;
699
700 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
702
703 let context = self.get_or_create_window_context(&transformed_spec)?;
705
706 struct EvaluatorAdapter<'a, 'b> {
708 evaluator: &'a mut ArithmeticEvaluator<'b>,
709 row_index: usize,
710 }
711
712 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
713 fn evaluate(
714 &mut self,
715 expr: &SqlExpression,
716 _row_index: usize,
717 ) -> Result<DataValue> {
718 self.evaluator.evaluate(expr, self.row_index)
719 }
720 }
721
722 let mut adapter = EvaluatorAdapter {
723 evaluator: self,
724 row_index,
725 };
726
727 return window_fn.compute(&context, row_index, args, &mut adapter);
729 }
730
731 let context = self.get_or_create_window_context(spec)?;
733
734 match name_upper.as_str() {
735 "LAG" => {
736 if args.is_empty() {
738 return Err(anyhow!("LAG requires at least 1 argument"));
739 }
740
741 let column = match &args[0] {
743 SqlExpression::Column(col) => col.clone(),
744 _ => return Err(anyhow!("LAG first argument must be a column")),
745 };
746
747 let offset = if args.len() > 1 {
749 match self.evaluate(&args[1], row_index)? {
750 DataValue::Integer(i) => i as i32,
751 _ => return Err(anyhow!("LAG offset must be an integer")),
752 }
753 } else {
754 1
755 };
756
757 Ok(context
759 .get_offset_value(row_index, -offset, &column)
760 .unwrap_or(DataValue::Null))
761 }
762 "LEAD" => {
763 if args.is_empty() {
765 return Err(anyhow!("LEAD requires at least 1 argument"));
766 }
767
768 let column = match &args[0] {
770 SqlExpression::Column(col) => col.clone(),
771 _ => return Err(anyhow!("LEAD first argument must be a column")),
772 };
773
774 let offset = if args.len() > 1 {
776 match self.evaluate(&args[1], row_index)? {
777 DataValue::Integer(i) => i as i32,
778 _ => return Err(anyhow!("LEAD offset must be an integer")),
779 }
780 } else {
781 1
782 };
783
784 Ok(context
786 .get_offset_value(row_index, offset, &column)
787 .unwrap_or(DataValue::Null))
788 }
789 "ROW_NUMBER" => {
790 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
792 }
793 "FIRST_VALUE" => {
794 if args.is_empty() {
796 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
797 }
798
799 let column = match &args[0] {
800 SqlExpression::Column(col) => col.clone(),
801 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
802 };
803
804 if context.has_frame() {
806 Ok(context
807 .get_frame_first_value(row_index, &column)
808 .unwrap_or(DataValue::Null))
809 } else {
810 Ok(context
811 .get_first_value(row_index, &column)
812 .unwrap_or(DataValue::Null))
813 }
814 }
815 "LAST_VALUE" => {
816 if args.is_empty() {
818 return Err(anyhow!("LAST_VALUE requires 1 argument"));
819 }
820
821 let column = match &args[0] {
822 SqlExpression::Column(col) => col.clone(),
823 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
824 };
825
826 if context.has_frame() {
828 Ok(context
829 .get_frame_last_value(row_index, &column)
830 .unwrap_or(DataValue::Null))
831 } else {
832 Ok(context
833 .get_last_value(row_index, &column)
834 .unwrap_or(DataValue::Null))
835 }
836 }
837 "SUM" => {
838 if args.is_empty() {
840 return Err(anyhow!("SUM requires 1 argument"));
841 }
842
843 let column = match &args[0] {
844 SqlExpression::Column(col) => col.clone(),
845 _ => return Err(anyhow!("SUM argument must be a column")),
846 };
847
848 if context.has_frame() {
850 Ok(context
851 .get_frame_sum(row_index, &column)
852 .unwrap_or(DataValue::Null))
853 } else {
854 Ok(context
855 .get_partition_sum(row_index, &column)
856 .unwrap_or(DataValue::Null))
857 }
858 }
859 "AVG" => {
860 if args.is_empty() {
862 return Err(anyhow!("AVG requires 1 argument"));
863 }
864
865 let column = match &args[0] {
866 SqlExpression::Column(col) => col.clone(),
867 _ => return Err(anyhow!("AVG argument must be a column")),
868 };
869
870 Ok(context
871 .get_frame_avg(row_index, &column)
872 .unwrap_or(DataValue::Null))
873 }
874 "STDDEV" | "STDEV" => {
875 if args.is_empty() {
877 return Err(anyhow!("STDDEV requires 1 argument"));
878 }
879
880 let column = match &args[0] {
881 SqlExpression::Column(col) => col.clone(),
882 _ => return Err(anyhow!("STDDEV argument must be a column")),
883 };
884
885 Ok(context
886 .get_frame_stddev(row_index, &column)
887 .unwrap_or(DataValue::Null))
888 }
889 "VARIANCE" | "VAR" => {
890 if args.is_empty() {
892 return Err(anyhow!("VARIANCE requires 1 argument"));
893 }
894
895 let column = match &args[0] {
896 SqlExpression::Column(col) => col.clone(),
897 _ => return Err(anyhow!("VARIANCE argument must be a column")),
898 };
899
900 Ok(context
901 .get_frame_variance(row_index, &column)
902 .unwrap_or(DataValue::Null))
903 }
904 "MIN" => {
905 if args.is_empty() {
907 return Err(anyhow!("MIN requires 1 argument"));
908 }
909
910 let column = match &args[0] {
911 SqlExpression::Column(col) => col.clone(),
912 _ => return Err(anyhow!("MIN argument must be a column")),
913 };
914
915 let frame_rows = context.get_frame_rows(row_index);
916 if frame_rows.is_empty() {
917 return Ok(DataValue::Null);
918 }
919
920 let source_table = context.source();
921 let col_idx = source_table
922 .get_column_index(&column)
923 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
924
925 let mut min_value: Option<DataValue> = None;
926 for &row_idx in &frame_rows {
927 if let Some(value) = source_table.get_value(row_idx, col_idx) {
928 if !matches!(value, DataValue::Null) {
929 match &min_value {
930 None => min_value = Some(value.clone()),
931 Some(current_min) => {
932 if value < current_min {
933 min_value = Some(value.clone());
934 }
935 }
936 }
937 }
938 }
939 }
940
941 Ok(min_value.unwrap_or(DataValue::Null))
942 }
943 "MAX" => {
944 if args.is_empty() {
946 return Err(anyhow!("MAX requires 1 argument"));
947 }
948
949 let column = match &args[0] {
950 SqlExpression::Column(col) => col.clone(),
951 _ => return Err(anyhow!("MAX argument must be a column")),
952 };
953
954 let frame_rows = context.get_frame_rows(row_index);
955 if frame_rows.is_empty() {
956 return Ok(DataValue::Null);
957 }
958
959 let source_table = context.source();
960 let col_idx = source_table
961 .get_column_index(&column)
962 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
963
964 let mut max_value: Option<DataValue> = None;
965 for &row_idx in &frame_rows {
966 if let Some(value) = source_table.get_value(row_idx, col_idx) {
967 if !matches!(value, DataValue::Null) {
968 match &max_value {
969 None => max_value = Some(value.clone()),
970 Some(current_max) => {
971 if value > current_max {
972 max_value = Some(value.clone());
973 }
974 }
975 }
976 }
977 }
978 }
979
980 Ok(max_value.unwrap_or(DataValue::Null))
981 }
982 "COUNT" => {
983 if args.is_empty() {
987 if context.has_frame() {
989 Ok(context
990 .get_frame_count(row_index, None)
991 .unwrap_or(DataValue::Null))
992 } else {
993 Ok(context
994 .get_partition_count(row_index, None)
995 .unwrap_or(DataValue::Null))
996 }
997 } else {
998 let column = match &args[0] {
1000 SqlExpression::Column(col) => {
1001 if col == "*" {
1002 if context.has_frame() {
1004 return Ok(context
1005 .get_frame_count(row_index, None)
1006 .unwrap_or(DataValue::Null));
1007 } else {
1008 return Ok(context
1009 .get_partition_count(row_index, None)
1010 .unwrap_or(DataValue::Null));
1011 }
1012 }
1013 col.clone()
1014 }
1015 SqlExpression::StringLiteral(s) if s == "*" => {
1016 if context.has_frame() {
1018 return Ok(context
1019 .get_frame_count(row_index, None)
1020 .unwrap_or(DataValue::Null));
1021 } else {
1022 return Ok(context
1023 .get_partition_count(row_index, None)
1024 .unwrap_or(DataValue::Null));
1025 }
1026 }
1027 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1028 };
1029
1030 if context.has_frame() {
1032 Ok(context
1033 .get_frame_count(row_index, Some(&column))
1034 .unwrap_or(DataValue::Null))
1035 } else {
1036 Ok(context
1037 .get_partition_count(row_index, Some(&column))
1038 .unwrap_or(DataValue::Null))
1039 }
1040 }
1041 }
1042 _ => Err(anyhow!("Unknown window function: {}", name)),
1043 }
1044 }
1045
1046 fn evaluate_method_call(
1048 &mut self,
1049 object: &str,
1050 method: &str,
1051 args: &[SqlExpression],
1052 row_index: usize,
1053 ) -> Result<DataValue> {
1054 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1056 let suggestion = self.find_similar_column(object);
1057 match suggestion {
1058 Some(similar) => {
1059 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1060 }
1061 None => anyhow!("Column '{}' not found", object),
1062 }
1063 })?;
1064
1065 let cell_value = self.table.get_value(row_index, col_index).cloned();
1066
1067 self.evaluate_method_on_value(
1068 &cell_value.unwrap_or(DataValue::Null),
1069 method,
1070 args,
1071 row_index,
1072 )
1073 }
1074
1075 fn evaluate_method_on_value(
1077 &mut self,
1078 value: &DataValue,
1079 method: &str,
1080 args: &[SqlExpression],
1081 row_index: usize,
1082 ) -> Result<DataValue> {
1083 let function_name = match method.to_lowercase().as_str() {
1088 "trim" => "TRIM",
1089 "trimstart" | "trimbegin" => "TRIMSTART",
1090 "trimend" => "TRIMEND",
1091 "length" | "len" => "LENGTH",
1092 "contains" => "CONTAINS",
1093 "startswith" => "STARTSWITH",
1094 "endswith" => "ENDSWITH",
1095 "indexof" => "INDEXOF",
1096 _ => method, };
1098
1099 if self.function_registry.get(function_name).is_some() {
1101 debug!(
1102 "Proxying method '{}' through function registry as '{}'",
1103 method, function_name
1104 );
1105
1106 let mut func_args = vec![value.clone()];
1108
1109 for arg in args {
1111 func_args.push(self.evaluate(arg, row_index)?);
1112 }
1113
1114 let func = self.function_registry.get(function_name).unwrap();
1116 return func.evaluate(&func_args);
1117 }
1118
1119 Err(anyhow!(
1122 "Method '{}' not found. It should be registered in the function registry.",
1123 method
1124 ))
1125 }
1126
1127 fn evaluate_case_expression(
1129 &mut self,
1130 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1131 else_branch: &Option<Box<SqlExpression>>,
1132 row_index: usize,
1133 ) -> Result<DataValue> {
1134 debug!(
1135 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1136 row_index
1137 );
1138
1139 for branch in when_branches {
1141 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1143
1144 if condition_result {
1145 debug!("CASE: WHEN condition matched, evaluating result expression");
1146 return self.evaluate(&branch.result, row_index);
1147 }
1148 }
1149
1150 if let Some(else_expr) = else_branch {
1152 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1153 self.evaluate(else_expr, row_index)
1154 } else {
1155 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1156 Ok(DataValue::Null)
1157 }
1158 }
1159
1160 fn evaluate_condition_as_bool(
1162 &mut self,
1163 expr: &SqlExpression,
1164 row_index: usize,
1165 ) -> Result<bool> {
1166 let value = self.evaluate(expr, row_index)?;
1167
1168 match value {
1169 DataValue::Boolean(b) => Ok(b),
1170 DataValue::Integer(i) => Ok(i != 0),
1171 DataValue::Float(f) => Ok(f != 0.0),
1172 DataValue::Null => Ok(false),
1173 DataValue::String(s) => Ok(!s.is_empty()),
1174 DataValue::InternedString(s) => Ok(!s.is_empty()),
1175 _ => Ok(true), }
1177 }
1178}
1179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Single-row table: a = 10 (Integer), b = 2.5 (Float), c = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_qualified_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Qualified names resolve by their final segment.
        let expr = SqlExpression::Column("test.a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_unknown_column_is_reported() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("zzzz".to_string());
        let err = evaluator.evaluate(&expr, 0).unwrap_err();
        assert!(err.to_string().contains("Column 'zzzz' not found"));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 1.25 is exactly representable and, unlike 3.14, does not trip
        // clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("1.25".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(1.25));

        let expr = SqlExpression::NumberLiteral("abc".to_string());
        assert!(evaluator.evaluate(&expr, 0).is_err());
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_subtract_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .subtract_values(&DataValue::Integer(10), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_null_propagation() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Null, &DataValue::Integer(1))
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let result = evaluator
            .divide_values(&DataValue::Integer(1), &DataValue::Null)
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}