1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12pub struct ArithmeticEvaluator<'a> {
15 table: &'a DataTable,
16 date_notation: String,
17 function_registry: Arc<FunctionRegistry>,
18 aggregate_registry: Arc<AggregateRegistry>,
19 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, }
22
23impl<'a> ArithmeticEvaluator<'a> {
24 #[must_use]
25 pub fn new(table: &'a DataTable) -> Self {
26 Self {
27 table,
28 date_notation: "us".to_string(),
29 function_registry: Arc::new(FunctionRegistry::new()),
30 aggregate_registry: Arc::new(AggregateRegistry::new()),
31 visible_rows: None,
32 window_contexts: HashMap::new(),
33 }
34 }
35
36 #[must_use]
37 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38 Self {
39 table,
40 date_notation,
41 function_registry: Arc::new(FunctionRegistry::new()),
42 aggregate_registry: Arc::new(AggregateRegistry::new()),
43 visible_rows: None,
44 window_contexts: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51 self.visible_rows = Some(rows);
52 self
53 }
54
55 #[must_use]
56 pub fn with_date_notation_and_registry(
57 table: &'a DataTable,
58 date_notation: String,
59 function_registry: Arc<FunctionRegistry>,
60 ) -> Self {
61 Self {
62 table,
63 date_notation,
64 function_registry,
65 aggregate_registry: Arc::new(AggregateRegistry::new()),
66 visible_rows: None,
67 window_contexts: HashMap::new(),
68 }
69 }
70
71 fn find_similar_column(&self, name: &str) -> Option<String> {
73 let columns = self.table.column_names();
74 let mut best_match: Option<(String, usize)> = None;
75
76 for col in columns {
77 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78 let max_distance = if name.len() > 10 { 3 } else { 2 };
81 if distance <= max_distance {
82 match &best_match {
83 None => best_match = Some((col, distance)),
84 Some((_, best_dist)) if distance < *best_dist => {
85 best_match = Some((col, distance));
86 }
87 _ => {}
88 }
89 }
90 }
91
92 best_match.map(|(name, _)| name)
93 }
94
95 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99 }
100
101 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103 debug!(
104 "ArithmeticEvaluator: evaluating {:?} for row {}",
105 expr, row_index
106 );
107
108 match expr {
109 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113 SqlExpression::Null => Ok(DataValue::Null),
114 SqlExpression::BinaryOp { left, op, right } => {
115 self.evaluate_binary_op(left, op, right, row_index)
116 }
117 SqlExpression::FunctionCall { name, args } => {
118 self.evaluate_function(name, args, row_index)
119 }
120 SqlExpression::WindowFunction {
121 name,
122 args,
123 window_spec,
124 } => self.evaluate_window_function(name, args, window_spec, row_index),
125 SqlExpression::MethodCall {
126 object,
127 method,
128 args,
129 } => self.evaluate_method_call(object, method, args, row_index),
130 SqlExpression::ChainedMethodCall { base, method, args } => {
131 let base_value = self.evaluate(base, row_index)?;
133 self.evaluate_method_on_value(&base_value, method, args, row_index)
134 }
135 SqlExpression::CaseExpression {
136 when_branches,
137 else_branch,
138 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
139 _ => Err(anyhow!(
140 "Unsupported expression type for arithmetic evaluation: {:?}",
141 expr
142 )),
143 }
144 }
145
146 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
148 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
149 let suggestion = self.find_similar_column(column_name);
150 match suggestion {
151 Some(similar) => anyhow!(
152 "Column '{}' not found. Did you mean '{}'?",
153 column_name,
154 similar
155 ),
156 None => anyhow!("Column '{}' not found", column_name),
157 }
158 })?;
159
160 if row_index >= self.table.row_count() {
161 return Err(anyhow!("Row index {} out of bounds", row_index));
162 }
163
164 let row = self
165 .table
166 .get_row(row_index)
167 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
168
169 let value = row
170 .get(col_index)
171 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
172
173 Ok(value.clone())
174 }
175
176 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
178 if let Ok(int_val) = number_str.parse::<i64>() {
180 return Ok(DataValue::Integer(int_val));
181 }
182
183 if let Ok(float_val) = number_str.parse::<f64>() {
185 return Ok(DataValue::Float(float_val));
186 }
187
188 Err(anyhow!("Invalid number literal: {}", number_str))
189 }
190
191 fn evaluate_binary_op(
193 &mut self,
194 left: &SqlExpression,
195 op: &str,
196 right: &SqlExpression,
197 row_index: usize,
198 ) -> Result<DataValue> {
199 let left_val = self.evaluate(left, row_index)?;
200 let right_val = self.evaluate(right, row_index)?;
201
202 debug!(
203 "ArithmeticEvaluator: {} {} {}",
204 self.format_value(&left_val),
205 op,
206 self.format_value(&right_val)
207 );
208
209 match op {
210 "+" => self.add_values(&left_val, &right_val),
211 "-" => self.subtract_values(&left_val, &right_val),
212 "*" => self.multiply_values(&left_val, &right_val),
213 "/" => self.divide_values(&left_val, &right_val),
214 "%" => {
215 let args = vec![left.clone(), right.clone()];
217 self.evaluate_function("MOD", &args, row_index)
218 }
219 ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
221 "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
222 ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
223 "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
224 "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
225 "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
226 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
228 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
229 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
230 }
231 }
232
233 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
235 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
237 return Ok(DataValue::Null);
238 }
239
240 match (left, right) {
241 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
242 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
243 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
244 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
245 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
246 }
247 }
248
249 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
251 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
253 return Ok(DataValue::Null);
254 }
255
256 match (left, right) {
257 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
258 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
259 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
260 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
261 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
262 }
263 }
264
265 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
267 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
269 return Ok(DataValue::Null);
270 }
271
272 match (left, right) {
273 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
274 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
275 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
276 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
277 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
278 }
279 }
280
281 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
283 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
285 return Ok(DataValue::Null);
286 }
287
288 let is_zero = match right {
290 DataValue::Integer(0) => true,
291 DataValue::Float(f) if *f == 0.0 => true, _ => false,
293 };
294
295 if is_zero {
296 return Err(anyhow!("Division by zero"));
297 }
298
299 match (left, right) {
300 (DataValue::Integer(a), DataValue::Integer(b)) => {
301 if a % b == 0 {
303 Ok(DataValue::Integer(a / b))
304 } else {
305 Ok(DataValue::Float(*a as f64 / *b as f64))
306 }
307 }
308 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
309 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
310 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
311 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
312 }
313 }
314
315 fn format_value(&self, value: &DataValue) -> String {
317 match value {
318 DataValue::Integer(i) => i.to_string(),
319 DataValue::Float(f) => f.to_string(),
320 DataValue::String(s) => format!("'{s}'"),
321 _ => format!("{value:?}"),
322 }
323 }
324
325 fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
327 where
328 F: Fn(f64, f64) -> bool,
329 {
330 debug!(
331 "ArithmeticEvaluator: comparing values {:?} and {:?}",
332 left, right
333 );
334
335 let result = match (left, right) {
336 (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
338 (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
339 (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
340 (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
341
342 (DataValue::String(a), DataValue::String(b)) => {
344 let a_num = a.parse::<f64>();
345 let b_num = b.parse::<f64>();
346 match (a_num, b_num) {
347 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
350 }
351 (DataValue::InternedString(a), DataValue::InternedString(b)) => {
352 let a_num = a.parse::<f64>();
353 let b_num = b.parse::<f64>();
354 match (a_num, b_num) {
355 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
358 }
359 (DataValue::String(a), DataValue::InternedString(b)) => {
360 let a_num = a.parse::<f64>();
361 let b_num = b.parse::<f64>();
362 match (a_num, b_num) {
363 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
366 }
367 (DataValue::InternedString(a), DataValue::String(b)) => {
368 let a_num = a.parse::<f64>();
369 let b_num = b.parse::<f64>();
370 match (a_num, b_num) {
371 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
374 }
375
376 (DataValue::String(a), DataValue::Integer(b)) => {
378 match a.parse::<f64>() {
379 Ok(a_val) => op(a_val, *b as f64),
380 Err(_) => false, }
382 }
383 (DataValue::Integer(a), DataValue::String(b)) => {
384 match b.parse::<f64>() {
385 Ok(b_val) => op(*a as f64, b_val),
386 Err(_) => false, }
388 }
389 (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
390 Ok(a_val) => op(a_val, *b),
391 Err(_) => false,
392 },
393 (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
394 Ok(b_val) => op(*a, b_val),
395 Err(_) => false,
396 },
397
398 (DataValue::Null, _) | (_, DataValue::Null) => false,
400
401 (DataValue::Boolean(a), DataValue::Boolean(b)) => {
403 op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
404 }
405
406 _ => {
407 debug!(
408 "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
409 left, right
410 );
411 false
412 }
413 };
414
415 debug!("ArithmeticEvaluator: comparison result: {}", result);
416 Ok(DataValue::Boolean(result))
417 }
418
419 fn evaluate_function(
421 &mut self,
422 name: &str,
423 args: &[SqlExpression],
424 row_index: usize,
425 ) -> Result<DataValue> {
426 let name_upper = name.to_uppercase();
428
429 if name_upper == "COUNT" && args.len() == 1 {
431 match &args[0] {
432 SqlExpression::Column(col) if col == "*" => {
433 let count = if let Some(ref visible) = self.visible_rows {
435 visible.len() as i64
436 } else {
437 self.table.rows.len() as i64
438 };
439 return Ok(DataValue::Integer(count));
440 }
441 SqlExpression::StringLiteral(s) if s == "*" => {
442 let count = if let Some(ref visible) = self.visible_rows {
444 visible.len() as i64
445 } else {
446 self.table.rows.len() as i64
447 };
448 return Ok(DataValue::Integer(count));
449 }
450 _ => {
451 }
453 }
454 }
455
456 if self.aggregate_registry.get(&name_upper).is_some() {
458 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
460 visible.clone()
461 } else {
462 (0..self.table.rows.len()).collect()
463 };
464
465 let values = if !args.is_empty()
467 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
468 {
469 let mut vals = Vec::new();
471 for &row_idx in &rows_to_process {
472 let value = self.evaluate(&args[0], row_idx)?;
473 vals.push(value);
474 }
475 Some(vals)
476 } else {
477 None
478 };
479
480 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
482 let mut state = agg_func.init();
483
484 if let Some(values) = values {
485 for value in &values {
487 agg_func.accumulate(&mut state, value)?;
488 }
489 } else {
490 for _ in &rows_to_process {
492 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
493 }
494 }
495
496 return Ok(agg_func.finalize(state));
497 }
498
499 if self.function_registry.get(name).is_some() {
501 let mut evaluated_args = Vec::new();
503 for arg in args {
504 evaluated_args.push(self.evaluate(arg, row_index)?);
505 }
506
507 let func = self.function_registry.get(name).unwrap();
509 return func.evaluate(&evaluated_args);
510 }
511
512 Err(anyhow!("Unknown function: {}", name))
514 }
515
516 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
518 let key = format!("{:?}", spec);
520
521 if let Some(context) = self.window_contexts.get(&key) {
522 return Ok(Arc::clone(context));
523 }
524
525 let data_view = if let Some(ref visible_rows) = self.visible_rows {
527 let mut view = DataView::new(Arc::new(self.table.clone()));
529 view
532 } else {
533 DataView::new(Arc::new(self.table.clone()))
534 };
535
536 let context = WindowContext::new(
538 Arc::new(data_view),
539 spec.partition_by.clone(),
540 spec.order_by.clone(),
541 )?;
542
543 let context = Arc::new(context);
544 self.window_contexts.insert(key, Arc::clone(&context));
545 Ok(context)
546 }
547
548 fn evaluate_window_function(
550 &mut self,
551 name: &str,
552 args: &[SqlExpression],
553 spec: &WindowSpec,
554 row_index: usize,
555 ) -> Result<DataValue> {
556 let context = self.get_or_create_window_context(spec)?;
557 let name_upper = name.to_uppercase();
558
559 match name_upper.as_str() {
560 "LAG" => {
561 if args.is_empty() {
563 return Err(anyhow!("LAG requires at least 1 argument"));
564 }
565
566 let column = match &args[0] {
568 SqlExpression::Column(col) => col.clone(),
569 _ => return Err(anyhow!("LAG first argument must be a column")),
570 };
571
572 let offset = if args.len() > 1 {
574 match self.evaluate(&args[1], row_index)? {
575 DataValue::Integer(i) => i as i32,
576 _ => return Err(anyhow!("LAG offset must be an integer")),
577 }
578 } else {
579 1
580 };
581
582 Ok(context
584 .get_offset_value(row_index, -offset, &column)
585 .unwrap_or(DataValue::Null))
586 }
587 "LEAD" => {
588 if args.is_empty() {
590 return Err(anyhow!("LEAD requires at least 1 argument"));
591 }
592
593 let column = match &args[0] {
595 SqlExpression::Column(col) => col.clone(),
596 _ => return Err(anyhow!("LEAD first argument must be a column")),
597 };
598
599 let offset = if args.len() > 1 {
601 match self.evaluate(&args[1], row_index)? {
602 DataValue::Integer(i) => i as i32,
603 _ => return Err(anyhow!("LEAD offset must be an integer")),
604 }
605 } else {
606 1
607 };
608
609 Ok(context
611 .get_offset_value(row_index, offset, &column)
612 .unwrap_or(DataValue::Null))
613 }
614 "ROW_NUMBER" => {
615 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
617 }
618 "FIRST_VALUE" => {
619 if args.is_empty() {
621 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
622 }
623
624 let column = match &args[0] {
625 SqlExpression::Column(col) => col.clone(),
626 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
627 };
628
629 Ok(context
630 .get_first_value(row_index, &column)
631 .unwrap_or(DataValue::Null))
632 }
633 "LAST_VALUE" => {
634 if args.is_empty() {
636 return Err(anyhow!("LAST_VALUE requires 1 argument"));
637 }
638
639 let column = match &args[0] {
640 SqlExpression::Column(col) => col.clone(),
641 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
642 };
643
644 Ok(context
645 .get_last_value(row_index, &column)
646 .unwrap_or(DataValue::Null))
647 }
648 "SUM" => {
649 if args.is_empty() {
651 return Err(anyhow!("SUM requires 1 argument"));
652 }
653
654 let column = match &args[0] {
655 SqlExpression::Column(col) => col.clone(),
656 _ => return Err(anyhow!("SUM argument must be a column")),
657 };
658
659 Ok(context
660 .get_partition_sum(row_index, &column)
661 .unwrap_or(DataValue::Null))
662 }
663 "COUNT" => {
664 if args.is_empty() {
667 Ok(context
669 .get_partition_count(row_index, None)
670 .unwrap_or(DataValue::Null))
671 } else {
672 let column = match &args[0] {
674 SqlExpression::Column(col) => {
675 if col == "*" {
676 return Ok(context
678 .get_partition_count(row_index, None)
679 .unwrap_or(DataValue::Null));
680 }
681 col.clone()
682 }
683 SqlExpression::StringLiteral(s) if s == "*" => {
684 return Ok(context
686 .get_partition_count(row_index, None)
687 .unwrap_or(DataValue::Null));
688 }
689 _ => return Err(anyhow!("COUNT argument must be a column or *")),
690 };
691
692 Ok(context
694 .get_partition_count(row_index, Some(&column))
695 .unwrap_or(DataValue::Null))
696 }
697 }
698 _ => Err(anyhow!("Unknown window function: {}", name)),
699 }
700 }
701
702 fn evaluate_method_call(
704 &mut self,
705 object: &str,
706 method: &str,
707 args: &[SqlExpression],
708 row_index: usize,
709 ) -> Result<DataValue> {
710 let col_index = self.table.get_column_index(object).ok_or_else(|| {
712 let suggestion = self.find_similar_column(object);
713 match suggestion {
714 Some(similar) => {
715 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
716 }
717 None => anyhow!("Column '{}' not found", object),
718 }
719 })?;
720
721 let cell_value = self.table.get_value(row_index, col_index).cloned();
722
723 self.evaluate_method_on_value(
724 &cell_value.unwrap_or(DataValue::Null),
725 method,
726 args,
727 row_index,
728 )
729 }
730
731 fn evaluate_method_on_value(
733 &mut self,
734 value: &DataValue,
735 method: &str,
736 args: &[SqlExpression],
737 row_index: usize,
738 ) -> Result<DataValue> {
739 let function_name = match method.to_lowercase().as_str() {
744 "trim" => "TRIM",
745 "trimstart" | "trimbegin" => "TRIMSTART",
746 "trimend" => "TRIMEND",
747 "length" | "len" => "LENGTH",
748 "contains" => "CONTAINS",
749 "startswith" => "STARTSWITH",
750 "endswith" => "ENDSWITH",
751 "indexof" => "INDEXOF",
752 _ => method, };
754
755 if self.function_registry.get(function_name).is_some() {
757 debug!(
758 "Proxying method '{}' through function registry as '{}'",
759 method, function_name
760 );
761
762 let mut func_args = vec![value.clone()];
764
765 for arg in args {
767 func_args.push(self.evaluate(arg, row_index)?);
768 }
769
770 let func = self.function_registry.get(function_name).unwrap();
772 return func.evaluate(&func_args);
773 }
774
775 match method.to_lowercase().as_str() {
777 "trim" | "trimstart" | "trimend" => {
778 if !args.is_empty() {
779 return Err(anyhow!("{} takes no arguments", method));
780 }
781
782 let str_val = match value {
784 DataValue::String(s) => s.clone(),
785 DataValue::InternedString(s) => s.to_string(),
786 DataValue::Integer(n) => n.to_string(),
787 DataValue::Float(f) => f.to_string(),
788 DataValue::Boolean(b) => b.to_string(),
789 DataValue::DateTime(dt) => dt.clone(),
790 DataValue::Null => return Ok(DataValue::Null),
791 };
792
793 let result = match method.to_lowercase().as_str() {
794 "trim" => str_val.trim().to_string(),
795 "trimstart" => str_val.trim_start().to_string(),
796 "trimend" => str_val.trim_end().to_string(),
797 _ => unreachable!(),
798 };
799
800 Ok(DataValue::String(result))
801 }
802 "length" => {
803 if !args.is_empty() {
804 return Err(anyhow!("Length takes no arguments"));
805 }
806
807 let len = match value {
809 DataValue::String(s) => s.len(),
810 DataValue::InternedString(s) => s.len(),
811 DataValue::Integer(n) => n.to_string().len(),
812 DataValue::Float(f) => f.to_string().len(),
813 DataValue::Boolean(b) => b.to_string().len(),
814 DataValue::DateTime(dt) => dt.len(),
815 DataValue::Null => return Ok(DataValue::Integer(0)),
816 };
817
818 Ok(DataValue::Integer(len as i64))
819 }
820 "indexof" => {
821 if args.len() != 1 {
822 return Err(anyhow!("IndexOf requires exactly 1 argument"));
823 }
824
825 let search_str = match self.evaluate(&args[0], row_index)? {
827 DataValue::String(s) => s,
828 DataValue::InternedString(s) => s.to_string(),
829 DataValue::Integer(n) => n.to_string(),
830 DataValue::Float(f) => f.to_string(),
831 _ => return Err(anyhow!("IndexOf argument must be a string")),
832 };
833
834 let str_val = match value {
836 DataValue::String(s) => s.clone(),
837 DataValue::InternedString(s) => s.to_string(),
838 DataValue::Integer(n) => n.to_string(),
839 DataValue::Float(f) => f.to_string(),
840 DataValue::Boolean(b) => b.to_string(),
841 DataValue::DateTime(dt) => dt.clone(),
842 DataValue::Null => return Ok(DataValue::Integer(-1)),
843 };
844
845 let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
846
847 Ok(DataValue::Integer(index))
848 }
849 "contains" => {
850 if args.len() != 1 {
851 return Err(anyhow!("Contains requires exactly 1 argument"));
852 }
853
854 let search_str = match self.evaluate(&args[0], row_index)? {
856 DataValue::String(s) => s,
857 DataValue::InternedString(s) => s.to_string(),
858 DataValue::Integer(n) => n.to_string(),
859 DataValue::Float(f) => f.to_string(),
860 _ => return Err(anyhow!("Contains argument must be a string")),
861 };
862
863 let str_val = match value {
865 DataValue::String(s) => s.clone(),
866 DataValue::InternedString(s) => s.to_string(),
867 DataValue::Integer(n) => n.to_string(),
868 DataValue::Float(f) => f.to_string(),
869 DataValue::Boolean(b) => b.to_string(),
870 DataValue::DateTime(dt) => dt.clone(),
871 DataValue::Null => return Ok(DataValue::Boolean(false)),
872 };
873
874 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
876 Ok(DataValue::Boolean(result))
877 }
878 "startswith" => {
879 if args.len() != 1 {
880 return Err(anyhow!("StartsWith requires exactly 1 argument"));
881 }
882
883 let prefix = match self.evaluate(&args[0], row_index)? {
885 DataValue::String(s) => s,
886 DataValue::InternedString(s) => s.to_string(),
887 DataValue::Integer(n) => n.to_string(),
888 DataValue::Float(f) => f.to_string(),
889 _ => return Err(anyhow!("StartsWith argument must be a string")),
890 };
891
892 let str_val = match value {
894 DataValue::String(s) => s.clone(),
895 DataValue::InternedString(s) => s.to_string(),
896 DataValue::Integer(n) => n.to_string(),
897 DataValue::Float(f) => f.to_string(),
898 DataValue::Boolean(b) => b.to_string(),
899 DataValue::DateTime(dt) => dt.clone(),
900 DataValue::Null => return Ok(DataValue::Boolean(false)),
901 };
902
903 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
905 Ok(DataValue::Boolean(result))
906 }
907 "endswith" => {
908 if args.len() != 1 {
909 return Err(anyhow!("EndsWith requires exactly 1 argument"));
910 }
911
912 let suffix = match self.evaluate(&args[0], row_index)? {
914 DataValue::String(s) => s,
915 DataValue::InternedString(s) => s.to_string(),
916 DataValue::Integer(n) => n.to_string(),
917 DataValue::Float(f) => f.to_string(),
918 _ => return Err(anyhow!("EndsWith argument must be a string")),
919 };
920
921 let str_val = match value {
923 DataValue::String(s) => s.clone(),
924 DataValue::InternedString(s) => s.to_string(),
925 DataValue::Integer(n) => n.to_string(),
926 DataValue::Float(f) => f.to_string(),
927 DataValue::Boolean(b) => b.to_string(),
928 DataValue::DateTime(dt) => dt.clone(),
929 DataValue::Null => return Ok(DataValue::Boolean(false)),
930 };
931
932 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
934 Ok(DataValue::Boolean(result))
935 }
936 _ => Err(anyhow!("Unsupported method: {}", method)),
937 }
938 }
939
940 fn evaluate_case_expression(
942 &mut self,
943 when_branches: &[crate::sql::recursive_parser::WhenBranch],
944 else_branch: &Option<Box<SqlExpression>>,
945 row_index: usize,
946 ) -> Result<DataValue> {
947 debug!(
948 "ArithmeticEvaluator: evaluating CASE expression for row {}",
949 row_index
950 );
951
952 for branch in when_branches {
954 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
956
957 if condition_result {
958 debug!("CASE: WHEN condition matched, evaluating result expression");
959 return self.evaluate(&branch.result, row_index);
960 }
961 }
962
963 if let Some(else_expr) = else_branch {
965 debug!("CASE: No WHEN matched, evaluating ELSE expression");
966 self.evaluate(else_expr, row_index)
967 } else {
968 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
969 Ok(DataValue::Null)
970 }
971 }
972
973 fn evaluate_condition_as_bool(
975 &mut self,
976 expr: &SqlExpression,
977 row_index: usize,
978 ) -> Result<bool> {
979 let value = self.evaluate(expr, row_index)?;
980
981 match value {
982 DataValue::Boolean(b) => Ok(b),
983 DataValue::Integer(i) => Ok(i != 0),
984 DataValue::Float(f) => Ok(f != 0.0),
985 DataValue::Null => Ok(false),
986 DataValue::String(s) => Ok(!s.is_empty()),
987 DataValue::InternedString(s) => Ok(!s.is_empty()),
988 _ => Ok(true), }
990 }
991}
992
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// One-row table: a = 10 (Integer), b = 2.5 (Float), c = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_unknown_column_is_reported() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .evaluate(&SqlExpression::Column("aa".to_string()), 0)
            .unwrap_err();
        assert!(err.to_string().contains("Column 'aa' not found"));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is exactly representable and avoids clippy's `approx_constant` (3.14 ~ PI).
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_subtract_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .subtract_values(&DataValue::Integer(10), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_arithmetic_propagates_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Null, &DataValue::Integer(1))
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let result = evaluator
            .divide_values(&DataValue::Integer(1), &DataValue::Null)
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_compare_mixed_numeric_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .compare_values(&DataValue::Integer(3), &DataValue::Float(2.5), |a, b| a > b)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(true));

        let result = evaluator
            .compare_values(&DataValue::Null, &DataValue::Integer(1), |a, b| a != b)
            .unwrap();
        assert_eq!(result, DataValue::Boolean(false));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}