1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
8use crate::sql::window_context::WindowContext;
9use anyhow::{anyhow, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::debug;
13
14pub struct ArithmeticEvaluator<'a> {
17 table: &'a DataTable,
18 date_notation: String,
19 function_registry: Arc<FunctionRegistry>,
20 aggregate_registry: Arc<AggregateRegistry>,
21 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, }
24
25impl<'a> ArithmeticEvaluator<'a> {
26 #[must_use]
27 pub fn new(table: &'a DataTable) -> Self {
28 Self {
29 table,
30 date_notation: get_date_notation(),
31 function_registry: Arc::new(FunctionRegistry::new()),
32 aggregate_registry: Arc::new(AggregateRegistry::new()),
33 visible_rows: None,
34 window_contexts: HashMap::new(),
35 }
36 }
37
38 #[must_use]
39 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
40 Self {
41 table,
42 date_notation,
43 function_registry: Arc::new(FunctionRegistry::new()),
44 aggregate_registry: Arc::new(AggregateRegistry::new()),
45 visible_rows: None,
46 window_contexts: HashMap::new(),
47 }
48 }
49
50 #[must_use]
52 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
53 self.visible_rows = Some(rows);
54 self
55 }
56
57 #[must_use]
58 pub fn with_date_notation_and_registry(
59 table: &'a DataTable,
60 date_notation: String,
61 function_registry: Arc<FunctionRegistry>,
62 ) -> Self {
63 Self {
64 table,
65 date_notation,
66 function_registry,
67 aggregate_registry: Arc::new(AggregateRegistry::new()),
68 visible_rows: None,
69 window_contexts: HashMap::new(),
70 }
71 }
72
73 fn find_similar_column(&self, name: &str) -> Option<String> {
75 let columns = self.table.column_names();
76 let mut best_match: Option<(String, usize)> = None;
77
78 for col in columns {
79 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
80 let max_distance = if name.len() > 10 { 3 } else { 2 };
83 if distance <= max_distance {
84 match &best_match {
85 None => best_match = Some((col, distance)),
86 Some((_, best_dist)) if distance < *best_dist => {
87 best_match = Some((col, distance));
88 }
89 _ => {}
90 }
91 }
92 }
93
94 best_match.map(|(name, _)| name)
95 }
96
97 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
99 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
101 }
102
103 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
105 debug!(
106 "ArithmeticEvaluator: evaluating {:?} for row {}",
107 expr, row_index
108 );
109
110 match expr {
111 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
112 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
113 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
114 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
115 SqlExpression::Null => Ok(DataValue::Null),
116 SqlExpression::BinaryOp { left, op, right } => {
117 self.evaluate_binary_op(left, op, right, row_index)
118 }
119 SqlExpression::FunctionCall {
120 name,
121 args,
122 distinct,
123 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
124 SqlExpression::WindowFunction {
125 name,
126 args,
127 window_spec,
128 } => self.evaluate_window_function(name, args, window_spec, row_index),
129 SqlExpression::MethodCall {
130 object,
131 method,
132 args,
133 } => self.evaluate_method_call(object, method, args, row_index),
134 SqlExpression::ChainedMethodCall { base, method, args } => {
135 let base_value = self.evaluate(base, row_index)?;
137 self.evaluate_method_on_value(&base_value, method, args, row_index)
138 }
139 SqlExpression::CaseExpression {
140 when_branches,
141 else_branch,
142 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
143 _ => Err(anyhow!(
144 "Unsupported expression type for arithmetic evaluation: {:?}",
145 expr
146 )),
147 }
148 }
149
150 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
152 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
153 let suggestion = self.find_similar_column(column_name);
154 match suggestion {
155 Some(similar) => anyhow!(
156 "Column '{}' not found. Did you mean '{}'?",
157 column_name,
158 similar
159 ),
160 None => anyhow!("Column '{}' not found", column_name),
161 }
162 })?;
163
164 if row_index >= self.table.row_count() {
165 return Err(anyhow!("Row index {} out of bounds", row_index));
166 }
167
168 let row = self
169 .table
170 .get_row(row_index)
171 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
172
173 let value = row
174 .get(col_index)
175 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
176
177 Ok(value.clone())
178 }
179
180 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
182 if let Ok(int_val) = number_str.parse::<i64>() {
184 return Ok(DataValue::Integer(int_val));
185 }
186
187 if let Ok(float_val) = number_str.parse::<f64>() {
189 return Ok(DataValue::Float(float_val));
190 }
191
192 Err(anyhow!("Invalid number literal: {}", number_str))
193 }
194
195 fn evaluate_binary_op(
197 &mut self,
198 left: &SqlExpression,
199 op: &str,
200 right: &SqlExpression,
201 row_index: usize,
202 ) -> Result<DataValue> {
203 let left_val = self.evaluate(left, row_index)?;
204 let right_val = self.evaluate(right, row_index)?;
205
206 debug!(
207 "ArithmeticEvaluator: {} {} {}",
208 self.format_value(&left_val),
209 op,
210 self.format_value(&right_val)
211 );
212
213 match op {
214 "+" => self.add_values(&left_val, &right_val),
215 "-" => self.subtract_values(&left_val, &right_val),
216 "*" => self.multiply_values(&left_val, &right_val),
217 "/" => self.divide_values(&left_val, &right_val),
218 "%" => {
219 let args = vec![left.clone(), right.clone()];
221 self.evaluate_function("MOD", &args, row_index)
222 }
223 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
226 let result = compare_with_op(&left_val, &right_val, op, false);
227 Ok(DataValue::Boolean(result))
228 }
229 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
231 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
232 "AND" => {
234 let left_bool = self.to_bool(&left_val)?;
235 let right_bool = self.to_bool(&right_val)?;
236 Ok(DataValue::Boolean(left_bool && right_bool))
237 }
238 "OR" => {
239 let left_bool = self.to_bool(&left_val)?;
240 let right_bool = self.to_bool(&right_val)?;
241 Ok(DataValue::Boolean(left_bool || right_bool))
242 }
243 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
244 }
245 }
246
247 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
249 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
251 return Ok(DataValue::Null);
252 }
253
254 match (left, right) {
255 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
256 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
257 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
258 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
259 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
260 }
261 }
262
263 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
265 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
267 return Ok(DataValue::Null);
268 }
269
270 match (left, right) {
271 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
272 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
273 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
274 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
275 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
276 }
277 }
278
279 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
281 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
283 return Ok(DataValue::Null);
284 }
285
286 match (left, right) {
287 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
288 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
289 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
290 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
291 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
292 }
293 }
294
295 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
297 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
299 return Ok(DataValue::Null);
300 }
301
302 let is_zero = match right {
304 DataValue::Integer(0) => true,
305 DataValue::Float(f) if *f == 0.0 => true, _ => false,
307 };
308
309 if is_zero {
310 return Err(anyhow!("Division by zero"));
311 }
312
313 match (left, right) {
314 (DataValue::Integer(a), DataValue::Integer(b)) => {
315 if a % b == 0 {
317 Ok(DataValue::Integer(a / b))
318 } else {
319 Ok(DataValue::Float(*a as f64 / *b as f64))
320 }
321 }
322 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
323 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
324 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
325 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
326 }
327 }
328
329 fn format_value(&self, value: &DataValue) -> String {
331 match value {
332 DataValue::Integer(i) => i.to_string(),
333 DataValue::Float(f) => f.to_string(),
334 DataValue::String(s) => format!("'{s}'"),
335 _ => format!("{value:?}"),
336 }
337 }
338
339 fn to_bool(&self, value: &DataValue) -> Result<bool> {
341 match value {
342 DataValue::Boolean(b) => Ok(*b),
343 DataValue::Integer(i) => Ok(*i != 0),
344 DataValue::Float(f) => Ok(*f != 0.0),
345 DataValue::Null => Ok(false),
346 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
347 }
348 }
349
350 fn evaluate_function_with_distinct(
352 &mut self,
353 name: &str,
354 args: &[SqlExpression],
355 distinct: bool,
356 row_index: usize,
357 ) -> Result<DataValue> {
358 if distinct {
360 let name_upper = name.to_uppercase();
361
362 if name_upper == "COUNT"
364 || name_upper == "SUM"
365 || name_upper == "AVG"
366 || name_upper == "MIN"
367 || name_upper == "MAX"
368 {
369 return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
370 } else {
371 return Err(anyhow!(
372 "DISTINCT can only be used with aggregate functions"
373 ));
374 }
375 }
376
377 self.evaluate_function(name, args, row_index)
379 }
380
381 fn evaluate_aggregate_distinct(
382 &mut self,
383 name: &str,
384 args: &[SqlExpression],
385 _row_index: usize,
386 ) -> Result<DataValue> {
387 use std::collections::HashSet;
388
389 if args.is_empty() {
390 return Err(anyhow!("{} DISTINCT requires at least one argument", name));
391 }
392
393 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
395 visible.clone()
396 } else {
397 (0..self.table.rows.len()).collect()
398 };
399
400 let mut unique_values = HashSet::new();
402 let mut numeric_values = Vec::new();
403
404 for row_idx in &rows_to_process {
405 let value = self.evaluate(&args[0], *row_idx)?;
407
408 if matches!(value, DataValue::Null) {
410 continue;
411 }
412
413 let value_str = match &value {
415 DataValue::String(s) => s.clone(),
416 DataValue::InternedString(s) => s.to_string(),
417 DataValue::Integer(i) => i.to_string(),
418 DataValue::Float(f) => f.to_string(),
419 DataValue::Boolean(b) => b.to_string(),
420 DataValue::DateTime(dt) => dt.to_string(),
421 DataValue::Null => continue,
422 };
423
424 if unique_values.insert(value_str) {
426 if name != "COUNT" {
428 match value {
429 DataValue::Integer(i) => numeric_values.push(i as f64),
430 DataValue::Float(f) => numeric_values.push(f),
431 _ => {} }
433 }
434 }
435 }
436
437 match name {
439 "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
440 "SUM" => {
441 if numeric_values.is_empty() {
442 Ok(DataValue::Null)
443 } else {
444 let sum: f64 = numeric_values.iter().sum();
445 if sum.fract() == 0.0 && sum.abs() < 1e10 {
446 Ok(DataValue::Integer(sum as i64))
447 } else {
448 Ok(DataValue::Float(sum))
449 }
450 }
451 }
452 "AVG" => {
453 if numeric_values.is_empty() {
454 Ok(DataValue::Null)
455 } else {
456 let sum: f64 = numeric_values.iter().sum();
457 Ok(DataValue::Float(sum / numeric_values.len() as f64))
458 }
459 }
460 "MIN" => {
461 if numeric_values.is_empty() {
462 Ok(DataValue::Null)
463 } else {
464 let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
465 if min.fract() == 0.0 && min.abs() < 1e10 {
466 Ok(DataValue::Integer(min as i64))
467 } else {
468 Ok(DataValue::Float(min))
469 }
470 }
471 }
472 "MAX" => {
473 if numeric_values.is_empty() {
474 Ok(DataValue::Null)
475 } else {
476 let max = numeric_values
477 .iter()
478 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
479 if max.fract() == 0.0 && max.abs() < 1e10 {
480 Ok(DataValue::Integer(max as i64))
481 } else {
482 Ok(DataValue::Float(max))
483 }
484 }
485 }
486 _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
487 }
488 }
489
490 fn evaluate_function(
491 &mut self,
492 name: &str,
493 args: &[SqlExpression],
494 row_index: usize,
495 ) -> Result<DataValue> {
496 let name_upper = name.to_uppercase();
498
499 if name_upper == "COUNT" && args.len() == 1 {
501 match &args[0] {
502 SqlExpression::Column(col) if col == "*" => {
503 let count = if let Some(ref visible) = self.visible_rows {
505 visible.len() as i64
506 } else {
507 self.table.rows.len() as i64
508 };
509 return Ok(DataValue::Integer(count));
510 }
511 SqlExpression::StringLiteral(s) if s == "*" => {
512 let count = if let Some(ref visible) = self.visible_rows {
514 visible.len() as i64
515 } else {
516 self.table.rows.len() as i64
517 };
518 return Ok(DataValue::Integer(count));
519 }
520 _ => {
521 }
523 }
524 }
525
526 if self.aggregate_registry.get(&name_upper).is_some() {
528 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
530 visible.clone()
531 } else {
532 (0..self.table.rows.len()).collect()
533 };
534
535 let values = if !args.is_empty()
537 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
538 {
539 let mut vals = Vec::new();
541 for &row_idx in &rows_to_process {
542 let value = self.evaluate(&args[0], row_idx)?;
543 vals.push(value);
544 }
545 Some(vals)
546 } else {
547 None
548 };
549
550 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
552 let mut state = agg_func.init();
553
554 if let Some(values) = values {
555 for value in &values {
557 agg_func.accumulate(&mut state, value)?;
558 }
559 } else {
560 for _ in &rows_to_process {
562 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
563 }
564 }
565
566 return Ok(agg_func.finalize(state));
567 }
568
569 if self.function_registry.get(name).is_some() {
571 let mut evaluated_args = Vec::new();
573 for arg in args {
574 evaluated_args.push(self.evaluate(arg, row_index)?);
575 }
576
577 let func = self.function_registry.get(name).unwrap();
579 return func.evaluate(&evaluated_args);
580 }
581
582 Err(anyhow!("Unknown function: {}", name))
584 }
585
586 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
588 let key = format!("{:?}", spec);
590
591 if let Some(context) = self.window_contexts.get(&key) {
592 return Ok(Arc::clone(context));
593 }
594
595 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
597 let view = DataView::new(Arc::new(self.table.clone()));
599 view
602 } else {
603 DataView::new(Arc::new(self.table.clone()))
604 };
605
606 let context = WindowContext::new(
608 Arc::new(data_view),
609 spec.partition_by.clone(),
610 spec.order_by.clone(),
611 )?;
612
613 let context = Arc::new(context);
614 self.window_contexts.insert(key, Arc::clone(&context));
615 Ok(context)
616 }
617
618 fn evaluate_window_function(
620 &mut self,
621 name: &str,
622 args: &[SqlExpression],
623 spec: &WindowSpec,
624 row_index: usize,
625 ) -> Result<DataValue> {
626 let context = self.get_or_create_window_context(spec)?;
627 let name_upper = name.to_uppercase();
628
629 match name_upper.as_str() {
630 "LAG" => {
631 if args.is_empty() {
633 return Err(anyhow!("LAG requires at least 1 argument"));
634 }
635
636 let column = match &args[0] {
638 SqlExpression::Column(col) => col.clone(),
639 _ => return Err(anyhow!("LAG first argument must be a column")),
640 };
641
642 let offset = if args.len() > 1 {
644 match self.evaluate(&args[1], row_index)? {
645 DataValue::Integer(i) => i as i32,
646 _ => return Err(anyhow!("LAG offset must be an integer")),
647 }
648 } else {
649 1
650 };
651
652 Ok(context
654 .get_offset_value(row_index, -offset, &column)
655 .unwrap_or(DataValue::Null))
656 }
657 "LEAD" => {
658 if args.is_empty() {
660 return Err(anyhow!("LEAD requires at least 1 argument"));
661 }
662
663 let column = match &args[0] {
665 SqlExpression::Column(col) => col.clone(),
666 _ => return Err(anyhow!("LEAD first argument must be a column")),
667 };
668
669 let offset = if args.len() > 1 {
671 match self.evaluate(&args[1], row_index)? {
672 DataValue::Integer(i) => i as i32,
673 _ => return Err(anyhow!("LEAD offset must be an integer")),
674 }
675 } else {
676 1
677 };
678
679 Ok(context
681 .get_offset_value(row_index, offset, &column)
682 .unwrap_or(DataValue::Null))
683 }
684 "ROW_NUMBER" => {
685 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
687 }
688 "FIRST_VALUE" => {
689 if args.is_empty() {
691 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
692 }
693
694 let column = match &args[0] {
695 SqlExpression::Column(col) => col.clone(),
696 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
697 };
698
699 Ok(context
700 .get_first_value(row_index, &column)
701 .unwrap_or(DataValue::Null))
702 }
703 "LAST_VALUE" => {
704 if args.is_empty() {
706 return Err(anyhow!("LAST_VALUE requires 1 argument"));
707 }
708
709 let column = match &args[0] {
710 SqlExpression::Column(col) => col.clone(),
711 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
712 };
713
714 Ok(context
715 .get_last_value(row_index, &column)
716 .unwrap_or(DataValue::Null))
717 }
718 "SUM" => {
719 if args.is_empty() {
721 return Err(anyhow!("SUM requires 1 argument"));
722 }
723
724 let column = match &args[0] {
725 SqlExpression::Column(col) => col.clone(),
726 _ => return Err(anyhow!("SUM argument must be a column")),
727 };
728
729 Ok(context
730 .get_partition_sum(row_index, &column)
731 .unwrap_or(DataValue::Null))
732 }
733 "COUNT" => {
734 if args.is_empty() {
737 Ok(context
739 .get_partition_count(row_index, None)
740 .unwrap_or(DataValue::Null))
741 } else {
742 let column = match &args[0] {
744 SqlExpression::Column(col) => {
745 if col == "*" {
746 return Ok(context
748 .get_partition_count(row_index, None)
749 .unwrap_or(DataValue::Null));
750 }
751 col.clone()
752 }
753 SqlExpression::StringLiteral(s) if s == "*" => {
754 return Ok(context
756 .get_partition_count(row_index, None)
757 .unwrap_or(DataValue::Null));
758 }
759 _ => return Err(anyhow!("COUNT argument must be a column or *")),
760 };
761
762 Ok(context
764 .get_partition_count(row_index, Some(&column))
765 .unwrap_or(DataValue::Null))
766 }
767 }
768 _ => Err(anyhow!("Unknown window function: {}", name)),
769 }
770 }
771
772 fn evaluate_method_call(
774 &mut self,
775 object: &str,
776 method: &str,
777 args: &[SqlExpression],
778 row_index: usize,
779 ) -> Result<DataValue> {
780 let col_index = self.table.get_column_index(object).ok_or_else(|| {
782 let suggestion = self.find_similar_column(object);
783 match suggestion {
784 Some(similar) => {
785 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
786 }
787 None => anyhow!("Column '{}' not found", object),
788 }
789 })?;
790
791 let cell_value = self.table.get_value(row_index, col_index).cloned();
792
793 self.evaluate_method_on_value(
794 &cell_value.unwrap_or(DataValue::Null),
795 method,
796 args,
797 row_index,
798 )
799 }
800
801 fn evaluate_method_on_value(
803 &mut self,
804 value: &DataValue,
805 method: &str,
806 args: &[SqlExpression],
807 row_index: usize,
808 ) -> Result<DataValue> {
809 let function_name = match method.to_lowercase().as_str() {
814 "trim" => "TRIM",
815 "trimstart" | "trimbegin" => "TRIMSTART",
816 "trimend" => "TRIMEND",
817 "length" | "len" => "LENGTH",
818 "contains" => "CONTAINS",
819 "startswith" => "STARTSWITH",
820 "endswith" => "ENDSWITH",
821 "indexof" => "INDEXOF",
822 _ => method, };
824
825 if self.function_registry.get(function_name).is_some() {
827 debug!(
828 "Proxying method '{}' through function registry as '{}'",
829 method, function_name
830 );
831
832 let mut func_args = vec![value.clone()];
834
835 for arg in args {
837 func_args.push(self.evaluate(arg, row_index)?);
838 }
839
840 let func = self.function_registry.get(function_name).unwrap();
842 return func.evaluate(&func_args);
843 }
844
845 match method.to_lowercase().as_str() {
847 "trim" | "trimstart" | "trimend" => {
848 if !args.is_empty() {
849 return Err(anyhow!("{} takes no arguments", method));
850 }
851
852 let str_val = match value {
854 DataValue::String(s) => s.clone(),
855 DataValue::InternedString(s) => s.to_string(),
856 DataValue::Integer(n) => n.to_string(),
857 DataValue::Float(f) => f.to_string(),
858 DataValue::Boolean(b) => b.to_string(),
859 DataValue::DateTime(dt) => dt.clone(),
860 DataValue::Null => return Ok(DataValue::Null),
861 };
862
863 let result = match method.to_lowercase().as_str() {
864 "trim" => str_val.trim().to_string(),
865 "trimstart" => str_val.trim_start().to_string(),
866 "trimend" => str_val.trim_end().to_string(),
867 _ => unreachable!(),
868 };
869
870 Ok(DataValue::String(result))
871 }
872 "length" => {
873 if !args.is_empty() {
874 return Err(anyhow!("Length takes no arguments"));
875 }
876
877 let len = match value {
879 DataValue::String(s) => s.len(),
880 DataValue::InternedString(s) => s.len(),
881 DataValue::Integer(n) => n.to_string().len(),
882 DataValue::Float(f) => f.to_string().len(),
883 DataValue::Boolean(b) => b.to_string().len(),
884 DataValue::DateTime(dt) => dt.len(),
885 DataValue::Null => return Ok(DataValue::Integer(0)),
886 };
887
888 Ok(DataValue::Integer(len as i64))
889 }
890 "indexof" => {
891 if args.len() != 1 {
892 return Err(anyhow!("IndexOf requires exactly 1 argument"));
893 }
894
895 let search_str = match self.evaluate(&args[0], row_index)? {
897 DataValue::String(s) => s,
898 DataValue::InternedString(s) => s.to_string(),
899 DataValue::Integer(n) => n.to_string(),
900 DataValue::Float(f) => f.to_string(),
901 _ => return Err(anyhow!("IndexOf argument must be a string")),
902 };
903
904 let str_val = match value {
906 DataValue::String(s) => s.clone(),
907 DataValue::InternedString(s) => s.to_string(),
908 DataValue::Integer(n) => n.to_string(),
909 DataValue::Float(f) => f.to_string(),
910 DataValue::Boolean(b) => b.to_string(),
911 DataValue::DateTime(dt) => dt.clone(),
912 DataValue::Null => return Ok(DataValue::Integer(-1)),
913 };
914
915 let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
916
917 Ok(DataValue::Integer(index))
918 }
919 "contains" => {
920 if args.len() != 1 {
921 return Err(anyhow!("Contains requires exactly 1 argument"));
922 }
923
924 let search_str = match self.evaluate(&args[0], row_index)? {
926 DataValue::String(s) => s,
927 DataValue::InternedString(s) => s.to_string(),
928 DataValue::Integer(n) => n.to_string(),
929 DataValue::Float(f) => f.to_string(),
930 _ => return Err(anyhow!("Contains argument must be a string")),
931 };
932
933 let str_val = match value {
935 DataValue::String(s) => s.clone(),
936 DataValue::InternedString(s) => s.to_string(),
937 DataValue::Integer(n) => n.to_string(),
938 DataValue::Float(f) => f.to_string(),
939 DataValue::Boolean(b) => b.to_string(),
940 DataValue::DateTime(dt) => dt.clone(),
941 DataValue::Null => return Ok(DataValue::Boolean(false)),
942 };
943
944 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
946 Ok(DataValue::Boolean(result))
947 }
948 "startswith" => {
949 if args.len() != 1 {
950 return Err(anyhow!("StartsWith requires exactly 1 argument"));
951 }
952
953 let prefix = match self.evaluate(&args[0], row_index)? {
955 DataValue::String(s) => s,
956 DataValue::InternedString(s) => s.to_string(),
957 DataValue::Integer(n) => n.to_string(),
958 DataValue::Float(f) => f.to_string(),
959 _ => return Err(anyhow!("StartsWith argument must be a string")),
960 };
961
962 let str_val = match value {
964 DataValue::String(s) => s.clone(),
965 DataValue::InternedString(s) => s.to_string(),
966 DataValue::Integer(n) => n.to_string(),
967 DataValue::Float(f) => f.to_string(),
968 DataValue::Boolean(b) => b.to_string(),
969 DataValue::DateTime(dt) => dt.clone(),
970 DataValue::Null => return Ok(DataValue::Boolean(false)),
971 };
972
973 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
975 Ok(DataValue::Boolean(result))
976 }
977 "endswith" => {
978 if args.len() != 1 {
979 return Err(anyhow!("EndsWith requires exactly 1 argument"));
980 }
981
982 let suffix = match self.evaluate(&args[0], row_index)? {
984 DataValue::String(s) => s,
985 DataValue::InternedString(s) => s.to_string(),
986 DataValue::Integer(n) => n.to_string(),
987 DataValue::Float(f) => f.to_string(),
988 _ => return Err(anyhow!("EndsWith argument must be a string")),
989 };
990
991 let str_val = match value {
993 DataValue::String(s) => s.clone(),
994 DataValue::InternedString(s) => s.to_string(),
995 DataValue::Integer(n) => n.to_string(),
996 DataValue::Float(f) => f.to_string(),
997 DataValue::Boolean(b) => b.to_string(),
998 DataValue::DateTime(dt) => dt.clone(),
999 DataValue::Null => return Ok(DataValue::Boolean(false)),
1000 };
1001
1002 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1004 Ok(DataValue::Boolean(result))
1005 }
1006 _ => Err(anyhow!("Unsupported method: {}", method)),
1007 }
1008 }
1009
1010 fn evaluate_case_expression(
1012 &mut self,
1013 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1014 else_branch: &Option<Box<SqlExpression>>,
1015 row_index: usize,
1016 ) -> Result<DataValue> {
1017 debug!(
1018 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1019 row_index
1020 );
1021
1022 for branch in when_branches {
1024 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1026
1027 if condition_result {
1028 debug!("CASE: WHEN condition matched, evaluating result expression");
1029 return self.evaluate(&branch.result, row_index);
1030 }
1031 }
1032
1033 if let Some(else_expr) = else_branch {
1035 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1036 self.evaluate(else_expr, row_index)
1037 } else {
1038 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1039 Ok(DataValue::Null)
1040 }
1041 }
1042
1043 fn evaluate_condition_as_bool(
1045 &mut self,
1046 expr: &SqlExpression,
1047 row_index: usize,
1048 ) -> Result<bool> {
1049 let value = self.evaluate(expr, row_index)?;
1050
1051 match value {
1052 DataValue::Boolean(b) => Ok(b),
1053 DataValue::Integer(i) => Ok(i != 0),
1054 DataValue::Float(f) => Ok(f != 0.0),
1055 DataValue::Null => Ok(false),
1056 DataValue::String(s) => Ok(!s.is_empty()),
1057 DataValue::InternedString(s) => Ok(!s.is_empty()),
1058 _ => Ok(true), }
1060 }
1061}
1062
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// One-row table: `a` = 10 (Integer), `b` = 2.5 (Float), `c` = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_unknown_column_suggests_similar_name() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .evaluate(&SqlExpression::Column("aa".to_string()), 0)
            .unwrap_err();
        assert!(err.to_string().contains("Did you mean 'a'?"));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 rather than 3.14: clippy's deny-by-default `approx_constant`
        // rejects literals that approximate PI.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));

        let expr = SqlExpression::NumberLiteral("abc".to_string());
        assert!(evaluator.evaluate(&expr, 0).is_err());
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_null_propagates_through_arithmetic() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Null, &DataValue::Integer(1))
            .unwrap();
        assert_eq!(result, DataValue::Null);

        // NULL wins over the zero-divisor check.
        let result = evaluator
            .divide_values(&DataValue::Null, &DataValue::Integer(0))
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Float(0.0));
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}