1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
8use crate::sql::window_context::WindowContext;
9use anyhow::{anyhow, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::debug;
13
/// Evaluates SQL expressions (arithmetic, comparisons, function and method
/// calls, window functions and CASE) against the rows of a borrowed
/// [`DataTable`].
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows and columns expressions are evaluated against.
    table: &'a DataTable,
    /// Date notation captured when the evaluator was constructed.
    date_notation: String,
    /// Registry consulted for scalar functions and proxied method calls.
    function_registry: Arc<FunctionRegistry>,
    /// Registry consulted for aggregate functions.
    aggregate_registry: Arc<AggregateRegistry>,
    /// Row indices aggregates run over; `None` means every row of `table`.
    visible_rows: Option<Vec<usize>>,
    /// Window contexts cached under the `Debug` rendering of their `WindowSpec`.
    window_contexts: HashMap<String, Arc<WindowContext>>,
    /// Alias map supplied through `with_table_aliases`.
    table_aliases: HashMap<String, String>,
}
25
26impl<'a> ArithmeticEvaluator<'a> {
27 #[must_use]
28 pub fn new(table: &'a DataTable) -> Self {
29 Self {
30 table,
31 date_notation: get_date_notation(),
32 function_registry: Arc::new(FunctionRegistry::new()),
33 aggregate_registry: Arc::new(AggregateRegistry::new()),
34 visible_rows: None,
35 window_contexts: HashMap::new(),
36 table_aliases: HashMap::new(),
37 }
38 }
39
40 #[must_use]
41 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
42 Self {
43 table,
44 date_notation,
45 function_registry: Arc::new(FunctionRegistry::new()),
46 aggregate_registry: Arc::new(AggregateRegistry::new()),
47 visible_rows: None,
48 window_contexts: HashMap::new(),
49 table_aliases: HashMap::new(),
50 }
51 }
52
53 #[must_use]
55 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
56 self.visible_rows = Some(rows);
57 self
58 }
59
60 #[must_use]
62 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
63 self.table_aliases = aliases;
64 self
65 }
66
67 #[must_use]
68 pub fn with_date_notation_and_registry(
69 table: &'a DataTable,
70 date_notation: String,
71 function_registry: Arc<FunctionRegistry>,
72 ) -> Self {
73 Self {
74 table,
75 date_notation,
76 function_registry,
77 aggregate_registry: Arc::new(AggregateRegistry::new()),
78 visible_rows: None,
79 window_contexts: HashMap::new(),
80 table_aliases: HashMap::new(),
81 }
82 }
83
84 fn find_similar_column(&self, name: &str) -> Option<String> {
86 let columns = self.table.column_names();
87 let mut best_match: Option<(String, usize)> = None;
88
89 for col in columns {
90 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
91 let max_distance = if name.len() > 10 { 3 } else { 2 };
94 if distance <= max_distance {
95 match &best_match {
96 None => best_match = Some((col, distance)),
97 Some((_, best_dist)) if distance < *best_dist => {
98 best_match = Some((col, distance));
99 }
100 _ => {}
101 }
102 }
103 }
104
105 best_match.map(|(name, _)| name)
106 }
107
108 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
110 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
112 }
113
114 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
116 debug!(
117 "ArithmeticEvaluator: evaluating {:?} for row {}",
118 expr, row_index
119 );
120
121 match expr {
122 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
123 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
124 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
125 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
126 SqlExpression::Null => Ok(DataValue::Null),
127 SqlExpression::BinaryOp { left, op, right } => {
128 self.evaluate_binary_op(left, op, right, row_index)
129 }
130 SqlExpression::FunctionCall {
131 name,
132 args,
133 distinct,
134 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
135 SqlExpression::WindowFunction {
136 name,
137 args,
138 window_spec,
139 } => self.evaluate_window_function(name, args, window_spec, row_index),
140 SqlExpression::MethodCall {
141 object,
142 method,
143 args,
144 } => self.evaluate_method_call(object, method, args, row_index),
145 SqlExpression::ChainedMethodCall { base, method, args } => {
146 let base_value = self.evaluate(base, row_index)?;
148 self.evaluate_method_on_value(&base_value, method, args, row_index)
149 }
150 SqlExpression::CaseExpression {
151 when_branches,
152 else_branch,
153 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
154 _ => Err(anyhow!(
155 "Unsupported expression type for arithmetic evaluation: {:?}",
156 expr
157 )),
158 }
159 }
160
161 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
163 let resolved_column = if column_name.contains('.') {
165 if let Some(dot_pos) = column_name.rfind('.') {
167 let _table_or_alias = &column_name[..dot_pos];
168 let col_name = &column_name[dot_pos + 1..];
169
170 debug!(
173 "Resolving qualified column: {} -> {}",
174 column_name, col_name
175 );
176 col_name.to_string()
177 } else {
178 column_name.to_string()
179 }
180 } else {
181 column_name.to_string()
182 };
183
184 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
185 idx
186 } else if resolved_column != column_name {
187 if let Some(idx) = self.table.get_column_index(column_name) {
189 idx
190 } else {
191 let suggestion = self.find_similar_column(&resolved_column);
192 return Err(match suggestion {
193 Some(similar) => anyhow!(
194 "Column '{}' not found. Did you mean '{}'?",
195 column_name,
196 similar
197 ),
198 None => anyhow!("Column '{}' not found", column_name),
199 });
200 }
201 } else {
202 let suggestion = self.find_similar_column(&resolved_column);
203 return Err(match suggestion {
204 Some(similar) => anyhow!(
205 "Column '{}' not found. Did you mean '{}'?",
206 column_name,
207 similar
208 ),
209 None => anyhow!("Column '{}' not found", column_name),
210 });
211 };
212
213 if row_index >= self.table.row_count() {
214 return Err(anyhow!("Row index {} out of bounds", row_index));
215 }
216
217 let row = self
218 .table
219 .get_row(row_index)
220 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
221
222 let value = row
223 .get(col_index)
224 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
225
226 Ok(value.clone())
227 }
228
229 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
231 if let Ok(int_val) = number_str.parse::<i64>() {
233 return Ok(DataValue::Integer(int_val));
234 }
235
236 if let Ok(float_val) = number_str.parse::<f64>() {
238 return Ok(DataValue::Float(float_val));
239 }
240
241 Err(anyhow!("Invalid number literal: {}", number_str))
242 }
243
244 fn evaluate_binary_op(
246 &mut self,
247 left: &SqlExpression,
248 op: &str,
249 right: &SqlExpression,
250 row_index: usize,
251 ) -> Result<DataValue> {
252 let left_val = self.evaluate(left, row_index)?;
253 let right_val = self.evaluate(right, row_index)?;
254
255 debug!(
256 "ArithmeticEvaluator: {} {} {}",
257 self.format_value(&left_val),
258 op,
259 self.format_value(&right_val)
260 );
261
262 match op {
263 "+" => self.add_values(&left_val, &right_val),
264 "-" => self.subtract_values(&left_val, &right_val),
265 "*" => self.multiply_values(&left_val, &right_val),
266 "/" => self.divide_values(&left_val, &right_val),
267 "%" => {
268 let args = vec![left.clone(), right.clone()];
270 self.evaluate_function("MOD", &args, row_index)
271 }
272 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
275 let result = compare_with_op(&left_val, &right_val, op, false);
276 Ok(DataValue::Boolean(result))
277 }
278 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
280 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
281 "AND" => {
283 let left_bool = self.to_bool(&left_val)?;
284 let right_bool = self.to_bool(&right_val)?;
285 Ok(DataValue::Boolean(left_bool && right_bool))
286 }
287 "OR" => {
288 let left_bool = self.to_bool(&left_val)?;
289 let right_bool = self.to_bool(&right_val)?;
290 Ok(DataValue::Boolean(left_bool || right_bool))
291 }
292 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
293 }
294 }
295
296 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
298 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
300 return Ok(DataValue::Null);
301 }
302
303 match (left, right) {
304 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
305 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
306 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
307 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
308 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
309 }
310 }
311
312 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316 return Ok(DataValue::Null);
317 }
318
319 match (left, right) {
320 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
321 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
322 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
323 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
324 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
325 }
326 }
327
328 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332 return Ok(DataValue::Null);
333 }
334
335 match (left, right) {
336 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
337 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
338 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
339 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
340 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
341 }
342 }
343
344 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348 return Ok(DataValue::Null);
349 }
350
351 let is_zero = match right {
353 DataValue::Integer(0) => true,
354 DataValue::Float(f) if *f == 0.0 => true, _ => false,
356 };
357
358 if is_zero {
359 return Err(anyhow!("Division by zero"));
360 }
361
362 match (left, right) {
363 (DataValue::Integer(a), DataValue::Integer(b)) => {
364 if a % b == 0 {
366 Ok(DataValue::Integer(a / b))
367 } else {
368 Ok(DataValue::Float(*a as f64 / *b as f64))
369 }
370 }
371 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
372 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
373 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
374 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
375 }
376 }
377
378 fn format_value(&self, value: &DataValue) -> String {
380 match value {
381 DataValue::Integer(i) => i.to_string(),
382 DataValue::Float(f) => f.to_string(),
383 DataValue::String(s) => format!("'{s}'"),
384 _ => format!("{value:?}"),
385 }
386 }
387
388 fn to_bool(&self, value: &DataValue) -> Result<bool> {
390 match value {
391 DataValue::Boolean(b) => Ok(*b),
392 DataValue::Integer(i) => Ok(*i != 0),
393 DataValue::Float(f) => Ok(*f != 0.0),
394 DataValue::Null => Ok(false),
395 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
396 }
397 }
398
399 fn evaluate_function_with_distinct(
401 &mut self,
402 name: &str,
403 args: &[SqlExpression],
404 distinct: bool,
405 row_index: usize,
406 ) -> Result<DataValue> {
407 if distinct {
409 let name_upper = name.to_uppercase();
410
411 if name_upper == "COUNT"
413 || name_upper == "SUM"
414 || name_upper == "AVG"
415 || name_upper == "MIN"
416 || name_upper == "MAX"
417 {
418 return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
419 } else {
420 return Err(anyhow!(
421 "DISTINCT can only be used with aggregate functions"
422 ));
423 }
424 }
425
426 self.evaluate_function(name, args, row_index)
428 }
429
430 fn evaluate_aggregate_distinct(
431 &mut self,
432 name: &str,
433 args: &[SqlExpression],
434 _row_index: usize,
435 ) -> Result<DataValue> {
436 use std::collections::HashSet;
437
438 if args.is_empty() {
439 return Err(anyhow!("{} DISTINCT requires at least one argument", name));
440 }
441
442 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
444 visible.clone()
445 } else {
446 (0..self.table.rows.len()).collect()
447 };
448
449 let mut unique_values = HashSet::new();
451 let mut numeric_values = Vec::new();
452
453 for row_idx in &rows_to_process {
454 let value = self.evaluate(&args[0], *row_idx)?;
456
457 if matches!(value, DataValue::Null) {
459 continue;
460 }
461
462 let value_str = match &value {
464 DataValue::String(s) => s.clone(),
465 DataValue::InternedString(s) => s.to_string(),
466 DataValue::Integer(i) => i.to_string(),
467 DataValue::Float(f) => f.to_string(),
468 DataValue::Boolean(b) => b.to_string(),
469 DataValue::DateTime(dt) => dt.to_string(),
470 DataValue::Null => continue,
471 };
472
473 if unique_values.insert(value_str) {
475 if name != "COUNT" {
477 match value {
478 DataValue::Integer(i) => numeric_values.push(i as f64),
479 DataValue::Float(f) => numeric_values.push(f),
480 _ => {} }
482 }
483 }
484 }
485
486 match name {
488 "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
489 "SUM" => {
490 if numeric_values.is_empty() {
491 Ok(DataValue::Null)
492 } else {
493 let sum: f64 = numeric_values.iter().sum();
494 if sum.fract() == 0.0 && sum.abs() < 1e10 {
495 Ok(DataValue::Integer(sum as i64))
496 } else {
497 Ok(DataValue::Float(sum))
498 }
499 }
500 }
501 "AVG" => {
502 if numeric_values.is_empty() {
503 Ok(DataValue::Null)
504 } else {
505 let sum: f64 = numeric_values.iter().sum();
506 Ok(DataValue::Float(sum / numeric_values.len() as f64))
507 }
508 }
509 "MIN" => {
510 if numeric_values.is_empty() {
511 Ok(DataValue::Null)
512 } else {
513 let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
514 if min.fract() == 0.0 && min.abs() < 1e10 {
515 Ok(DataValue::Integer(min as i64))
516 } else {
517 Ok(DataValue::Float(min))
518 }
519 }
520 }
521 "MAX" => {
522 if numeric_values.is_empty() {
523 Ok(DataValue::Null)
524 } else {
525 let max = numeric_values
526 .iter()
527 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
528 if max.fract() == 0.0 && max.abs() < 1e10 {
529 Ok(DataValue::Integer(max as i64))
530 } else {
531 Ok(DataValue::Float(max))
532 }
533 }
534 }
535 _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
536 }
537 }
538
539 fn evaluate_function(
540 &mut self,
541 name: &str,
542 args: &[SqlExpression],
543 row_index: usize,
544 ) -> Result<DataValue> {
545 let name_upper = name.to_uppercase();
547
548 if name_upper == "COUNT" && args.len() == 1 {
550 match &args[0] {
551 SqlExpression::Column(col) if col == "*" => {
552 let count = if let Some(ref visible) = self.visible_rows {
554 visible.len() as i64
555 } else {
556 self.table.rows.len() as i64
557 };
558 return Ok(DataValue::Integer(count));
559 }
560 SqlExpression::StringLiteral(s) if s == "*" => {
561 let count = if let Some(ref visible) = self.visible_rows {
563 visible.len() as i64
564 } else {
565 self.table.rows.len() as i64
566 };
567 return Ok(DataValue::Integer(count));
568 }
569 _ => {
570 }
572 }
573 }
574
575 if self.aggregate_registry.get(&name_upper).is_some() {
577 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
579 visible.clone()
580 } else {
581 (0..self.table.rows.len()).collect()
582 };
583
584 let values = if !args.is_empty()
586 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
587 {
588 let mut vals = Vec::new();
590 for &row_idx in &rows_to_process {
591 let value = self.evaluate(&args[0], row_idx)?;
592 vals.push(value);
593 }
594 Some(vals)
595 } else {
596 None
597 };
598
599 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
601 let mut state = agg_func.init();
602
603 if let Some(values) = values {
604 for value in &values {
606 agg_func.accumulate(&mut state, value)?;
607 }
608 } else {
609 for _ in &rows_to_process {
611 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
612 }
613 }
614
615 return Ok(agg_func.finalize(state));
616 }
617
618 if self.function_registry.get(name).is_some() {
620 let mut evaluated_args = Vec::new();
622 for arg in args {
623 evaluated_args.push(self.evaluate(arg, row_index)?);
624 }
625
626 let func = self.function_registry.get(name).unwrap();
628 return func.evaluate(&evaluated_args);
629 }
630
631 Err(anyhow!("Unknown function: {}", name))
633 }
634
635 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
637 let key = format!("{:?}", spec);
639
640 if let Some(context) = self.window_contexts.get(&key) {
641 return Ok(Arc::clone(context));
642 }
643
644 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
646 let view = DataView::new(Arc::new(self.table.clone()));
648 view
651 } else {
652 DataView::new(Arc::new(self.table.clone()))
653 };
654
655 let context = WindowContext::new(
657 Arc::new(data_view),
658 spec.partition_by.clone(),
659 spec.order_by.clone(),
660 )?;
661
662 let context = Arc::new(context);
663 self.window_contexts.insert(key, Arc::clone(&context));
664 Ok(context)
665 }
666
667 fn evaluate_window_function(
669 &mut self,
670 name: &str,
671 args: &[SqlExpression],
672 spec: &WindowSpec,
673 row_index: usize,
674 ) -> Result<DataValue> {
675 let context = self.get_or_create_window_context(spec)?;
676 let name_upper = name.to_uppercase();
677
678 match name_upper.as_str() {
679 "LAG" => {
680 if args.is_empty() {
682 return Err(anyhow!("LAG requires at least 1 argument"));
683 }
684
685 let column = match &args[0] {
687 SqlExpression::Column(col) => col.clone(),
688 _ => return Err(anyhow!("LAG first argument must be a column")),
689 };
690
691 let offset = if args.len() > 1 {
693 match self.evaluate(&args[1], row_index)? {
694 DataValue::Integer(i) => i as i32,
695 _ => return Err(anyhow!("LAG offset must be an integer")),
696 }
697 } else {
698 1
699 };
700
701 Ok(context
703 .get_offset_value(row_index, -offset, &column)
704 .unwrap_or(DataValue::Null))
705 }
706 "LEAD" => {
707 if args.is_empty() {
709 return Err(anyhow!("LEAD requires at least 1 argument"));
710 }
711
712 let column = match &args[0] {
714 SqlExpression::Column(col) => col.clone(),
715 _ => return Err(anyhow!("LEAD first argument must be a column")),
716 };
717
718 let offset = if args.len() > 1 {
720 match self.evaluate(&args[1], row_index)? {
721 DataValue::Integer(i) => i as i32,
722 _ => return Err(anyhow!("LEAD offset must be an integer")),
723 }
724 } else {
725 1
726 };
727
728 Ok(context
730 .get_offset_value(row_index, offset, &column)
731 .unwrap_or(DataValue::Null))
732 }
733 "ROW_NUMBER" => {
734 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
736 }
737 "FIRST_VALUE" => {
738 if args.is_empty() {
740 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
741 }
742
743 let column = match &args[0] {
744 SqlExpression::Column(col) => col.clone(),
745 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
746 };
747
748 Ok(context
749 .get_first_value(row_index, &column)
750 .unwrap_or(DataValue::Null))
751 }
752 "LAST_VALUE" => {
753 if args.is_empty() {
755 return Err(anyhow!("LAST_VALUE requires 1 argument"));
756 }
757
758 let column = match &args[0] {
759 SqlExpression::Column(col) => col.clone(),
760 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
761 };
762
763 Ok(context
764 .get_last_value(row_index, &column)
765 .unwrap_or(DataValue::Null))
766 }
767 "SUM" => {
768 if args.is_empty() {
770 return Err(anyhow!("SUM requires 1 argument"));
771 }
772
773 let column = match &args[0] {
774 SqlExpression::Column(col) => col.clone(),
775 _ => return Err(anyhow!("SUM argument must be a column")),
776 };
777
778 Ok(context
779 .get_partition_sum(row_index, &column)
780 .unwrap_or(DataValue::Null))
781 }
782 "COUNT" => {
783 if args.is_empty() {
786 Ok(context
788 .get_partition_count(row_index, None)
789 .unwrap_or(DataValue::Null))
790 } else {
791 let column = match &args[0] {
793 SqlExpression::Column(col) => {
794 if col == "*" {
795 return Ok(context
797 .get_partition_count(row_index, None)
798 .unwrap_or(DataValue::Null));
799 }
800 col.clone()
801 }
802 SqlExpression::StringLiteral(s) if s == "*" => {
803 return Ok(context
805 .get_partition_count(row_index, None)
806 .unwrap_or(DataValue::Null));
807 }
808 _ => return Err(anyhow!("COUNT argument must be a column or *")),
809 };
810
811 Ok(context
813 .get_partition_count(row_index, Some(&column))
814 .unwrap_or(DataValue::Null))
815 }
816 }
817 _ => Err(anyhow!("Unknown window function: {}", name)),
818 }
819 }
820
821 fn evaluate_method_call(
823 &mut self,
824 object: &str,
825 method: &str,
826 args: &[SqlExpression],
827 row_index: usize,
828 ) -> Result<DataValue> {
829 let col_index = self.table.get_column_index(object).ok_or_else(|| {
831 let suggestion = self.find_similar_column(object);
832 match suggestion {
833 Some(similar) => {
834 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
835 }
836 None => anyhow!("Column '{}' not found", object),
837 }
838 })?;
839
840 let cell_value = self.table.get_value(row_index, col_index).cloned();
841
842 self.evaluate_method_on_value(
843 &cell_value.unwrap_or(DataValue::Null),
844 method,
845 args,
846 row_index,
847 )
848 }
849
850 fn evaluate_method_on_value(
852 &mut self,
853 value: &DataValue,
854 method: &str,
855 args: &[SqlExpression],
856 row_index: usize,
857 ) -> Result<DataValue> {
858 let function_name = match method.to_lowercase().as_str() {
863 "trim" => "TRIM",
864 "trimstart" | "trimbegin" => "TRIMSTART",
865 "trimend" => "TRIMEND",
866 "length" | "len" => "LENGTH",
867 "contains" => "CONTAINS",
868 "startswith" => "STARTSWITH",
869 "endswith" => "ENDSWITH",
870 "indexof" => "INDEXOF",
871 _ => method, };
873
874 if self.function_registry.get(function_name).is_some() {
876 debug!(
877 "Proxying method '{}' through function registry as '{}'",
878 method, function_name
879 );
880
881 let mut func_args = vec![value.clone()];
883
884 for arg in args {
886 func_args.push(self.evaluate(arg, row_index)?);
887 }
888
889 let func = self.function_registry.get(function_name).unwrap();
891 return func.evaluate(&func_args);
892 }
893
894 Err(anyhow!(
897 "Method '{}' not found. It should be registered in the function registry.",
898 method
899 ))
900 }
901
902 fn evaluate_case_expression(
904 &mut self,
905 when_branches: &[crate::sql::recursive_parser::WhenBranch],
906 else_branch: &Option<Box<SqlExpression>>,
907 row_index: usize,
908 ) -> Result<DataValue> {
909 debug!(
910 "ArithmeticEvaluator: evaluating CASE expression for row {}",
911 row_index
912 );
913
914 for branch in when_branches {
916 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
918
919 if condition_result {
920 debug!("CASE: WHEN condition matched, evaluating result expression");
921 return self.evaluate(&branch.result, row_index);
922 }
923 }
924
925 if let Some(else_expr) = else_branch {
927 debug!("CASE: No WHEN matched, evaluating ELSE expression");
928 self.evaluate(else_expr, row_index)
929 } else {
930 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
931 Ok(DataValue::Null)
932 }
933 }
934
935 fn evaluate_condition_as_bool(
937 &mut self,
938 expr: &SqlExpression,
939 row_index: usize,
940 ) -> Result<bool> {
941 let value = self.evaluate(expr, row_index)?;
942
943 match value {
944 DataValue::Boolean(b) => Ok(b),
945 DataValue::Integer(i) => Ok(i != 0),
946 DataValue::Float(f) => Ok(f != 0.0),
947 DataValue::Null => Ok(false),
948 DataValue::String(s) => Ok(!s.is_empty()),
949 DataValue::InternedString(s) => Ok(!s.is_empty()),
950 _ => Ok(true), }
952 }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table: `a = 10` (integer), `b = 2.5` (float),
    /// `c = 4` (integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // A value that is exactly representable in binary and is not mistaken
        // for an approximation of a well-known constant.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer + integer stays an integer.
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Mixing in a float promotes the result to a float.
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_arithmetic_with_null_yields_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Null, &DataValue::Integer(1))
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let result = evaluator
            .multiply_values(&DataValue::Float(2.0), &DataValue::Null)
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Inexact integer division falls back to a float.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}