1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12pub struct ArithmeticEvaluator<'a> {
15 table: &'a DataTable,
16 date_notation: String,
17 function_registry: Arc<FunctionRegistry>,
18 aggregate_registry: Arc<AggregateRegistry>,
19 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, }
22
23impl<'a> ArithmeticEvaluator<'a> {
24 #[must_use]
25 pub fn new(table: &'a DataTable) -> Self {
26 Self {
27 table,
28 date_notation: "us".to_string(),
29 function_registry: Arc::new(FunctionRegistry::new()),
30 aggregate_registry: Arc::new(AggregateRegistry::new()),
31 visible_rows: None,
32 window_contexts: HashMap::new(),
33 }
34 }
35
36 #[must_use]
37 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38 Self {
39 table,
40 date_notation,
41 function_registry: Arc::new(FunctionRegistry::new()),
42 aggregate_registry: Arc::new(AggregateRegistry::new()),
43 visible_rows: None,
44 window_contexts: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51 self.visible_rows = Some(rows);
52 self
53 }
54
55 #[must_use]
56 pub fn with_date_notation_and_registry(
57 table: &'a DataTable,
58 date_notation: String,
59 function_registry: Arc<FunctionRegistry>,
60 ) -> Self {
61 Self {
62 table,
63 date_notation,
64 function_registry,
65 aggregate_registry: Arc::new(AggregateRegistry::new()),
66 visible_rows: None,
67 window_contexts: HashMap::new(),
68 }
69 }
70
71 fn find_similar_column(&self, name: &str) -> Option<String> {
73 let columns = self.table.column_names();
74 let mut best_match: Option<(String, usize)> = None;
75
76 for col in columns {
77 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78 let max_distance = if name.len() > 10 { 3 } else { 2 };
81 if distance <= max_distance {
82 match &best_match {
83 None => best_match = Some((col, distance)),
84 Some((_, best_dist)) if distance < *best_dist => {
85 best_match = Some((col, distance));
86 }
87 _ => {}
88 }
89 }
90 }
91
92 best_match.map(|(name, _)| name)
93 }
94
95 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99 }
100
101 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103 debug!(
104 "ArithmeticEvaluator: evaluating {:?} for row {}",
105 expr, row_index
106 );
107
108 match expr {
109 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113 SqlExpression::BinaryOp { left, op, right } => {
114 self.evaluate_binary_op(left, op, right, row_index)
115 }
116 SqlExpression::FunctionCall { name, args } => {
117 self.evaluate_function(name, args, row_index)
118 }
119 SqlExpression::WindowFunction {
120 name,
121 args,
122 window_spec,
123 } => self.evaluate_window_function(name, args, window_spec, row_index),
124 SqlExpression::MethodCall {
125 object,
126 method,
127 args,
128 } => self.evaluate_method_call(object, method, args, row_index),
129 SqlExpression::ChainedMethodCall { base, method, args } => {
130 let base_value = self.evaluate(base, row_index)?;
132 self.evaluate_method_on_value(&base_value, method, args, row_index)
133 }
134 SqlExpression::CaseExpression {
135 when_branches,
136 else_branch,
137 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
138 _ => Err(anyhow!(
139 "Unsupported expression type for arithmetic evaluation: {:?}",
140 expr
141 )),
142 }
143 }
144
145 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
147 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
148 let suggestion = self.find_similar_column(column_name);
149 match suggestion {
150 Some(similar) => anyhow!(
151 "Column '{}' not found. Did you mean '{}'?",
152 column_name,
153 similar
154 ),
155 None => anyhow!("Column '{}' not found", column_name),
156 }
157 })?;
158
159 if row_index >= self.table.row_count() {
160 return Err(anyhow!("Row index {} out of bounds", row_index));
161 }
162
163 let row = self
164 .table
165 .get_row(row_index)
166 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
167
168 let value = row
169 .get(col_index)
170 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
171
172 Ok(value.clone())
173 }
174
175 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
177 if let Ok(int_val) = number_str.parse::<i64>() {
179 return Ok(DataValue::Integer(int_val));
180 }
181
182 if let Ok(float_val) = number_str.parse::<f64>() {
184 return Ok(DataValue::Float(float_val));
185 }
186
187 Err(anyhow!("Invalid number literal: {}", number_str))
188 }
189
190 fn evaluate_binary_op(
192 &mut self,
193 left: &SqlExpression,
194 op: &str,
195 right: &SqlExpression,
196 row_index: usize,
197 ) -> Result<DataValue> {
198 let left_val = self.evaluate(left, row_index)?;
199 let right_val = self.evaluate(right, row_index)?;
200
201 debug!(
202 "ArithmeticEvaluator: {} {} {}",
203 self.format_value(&left_val),
204 op,
205 self.format_value(&right_val)
206 );
207
208 match op {
209 "+" => self.add_values(&left_val, &right_val),
210 "-" => self.subtract_values(&left_val, &right_val),
211 "*" => self.multiply_values(&left_val, &right_val),
212 "/" => self.divide_values(&left_val, &right_val),
213 ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
215 "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
216 ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
217 "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
218 "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
219 "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
220 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
221 }
222 }
223
224 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
226 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
228 return Ok(DataValue::Null);
229 }
230
231 match (left, right) {
232 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
233 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
234 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
235 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
236 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
237 }
238 }
239
240 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
242 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
244 return Ok(DataValue::Null);
245 }
246
247 match (left, right) {
248 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
249 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
250 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
251 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
252 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
253 }
254 }
255
256 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
258 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
260 return Ok(DataValue::Null);
261 }
262
263 match (left, right) {
264 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
265 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
266 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
267 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
268 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
269 }
270 }
271
272 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
274 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
276 return Ok(DataValue::Null);
277 }
278
279 let is_zero = match right {
281 DataValue::Integer(0) => true,
282 DataValue::Float(f) if *f == 0.0 => true, _ => false,
284 };
285
286 if is_zero {
287 return Err(anyhow!("Division by zero"));
288 }
289
290 match (left, right) {
291 (DataValue::Integer(a), DataValue::Integer(b)) => {
292 if a % b == 0 {
294 Ok(DataValue::Integer(a / b))
295 } else {
296 Ok(DataValue::Float(*a as f64 / *b as f64))
297 }
298 }
299 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
300 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
301 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
302 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
303 }
304 }
305
306 fn format_value(&self, value: &DataValue) -> String {
308 match value {
309 DataValue::Integer(i) => i.to_string(),
310 DataValue::Float(f) => f.to_string(),
311 DataValue::String(s) => format!("'{s}'"),
312 _ => format!("{value:?}"),
313 }
314 }
315
316 fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
318 where
319 F: Fn(f64, f64) -> bool,
320 {
321 debug!(
322 "ArithmeticEvaluator: comparing values {:?} and {:?}",
323 left, right
324 );
325
326 let result = match (left, right) {
327 (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
329 (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
330 (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
331 (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
332
333 (DataValue::String(a), DataValue::String(b)) => {
335 let a_num = a.parse::<f64>();
336 let b_num = b.parse::<f64>();
337 match (a_num, b_num) {
338 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
341 }
342 (DataValue::InternedString(a), DataValue::InternedString(b)) => {
343 let a_num = a.parse::<f64>();
344 let b_num = b.parse::<f64>();
345 match (a_num, b_num) {
346 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
349 }
350 (DataValue::String(a), DataValue::InternedString(b)) => {
351 let a_num = a.parse::<f64>();
352 let b_num = b.parse::<f64>();
353 match (a_num, b_num) {
354 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
357 }
358 (DataValue::InternedString(a), DataValue::String(b)) => {
359 let a_num = a.parse::<f64>();
360 let b_num = b.parse::<f64>();
361 match (a_num, b_num) {
362 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
365 }
366
367 (DataValue::String(a), DataValue::Integer(b)) => {
369 match a.parse::<f64>() {
370 Ok(a_val) => op(a_val, *b as f64),
371 Err(_) => false, }
373 }
374 (DataValue::Integer(a), DataValue::String(b)) => {
375 match b.parse::<f64>() {
376 Ok(b_val) => op(*a as f64, b_val),
377 Err(_) => false, }
379 }
380 (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
381 Ok(a_val) => op(a_val, *b),
382 Err(_) => false,
383 },
384 (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
385 Ok(b_val) => op(*a, b_val),
386 Err(_) => false,
387 },
388
389 (DataValue::Null, _) | (_, DataValue::Null) => false,
391
392 (DataValue::Boolean(a), DataValue::Boolean(b)) => {
394 op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
395 }
396
397 _ => {
398 debug!(
399 "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
400 left, right
401 );
402 false
403 }
404 };
405
406 debug!("ArithmeticEvaluator: comparison result: {}", result);
407 Ok(DataValue::Boolean(result))
408 }
409
410 fn evaluate_function(
412 &mut self,
413 name: &str,
414 args: &[SqlExpression],
415 row_index: usize,
416 ) -> Result<DataValue> {
417 let name_upper = name.to_uppercase();
419
420 if name_upper == "COUNT" && args.len() == 1 {
422 match &args[0] {
423 SqlExpression::Column(col) if col == "*" => {
424 let count = if let Some(ref visible) = self.visible_rows {
426 visible.len() as i64
427 } else {
428 self.table.rows.len() as i64
429 };
430 return Ok(DataValue::Integer(count));
431 }
432 SqlExpression::StringLiteral(s) if s == "*" => {
433 let count = if let Some(ref visible) = self.visible_rows {
435 visible.len() as i64
436 } else {
437 self.table.rows.len() as i64
438 };
439 return Ok(DataValue::Integer(count));
440 }
441 _ => {
442 }
444 }
445 }
446
447 if self.aggregate_registry.get(&name_upper).is_some() {
449 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
451 visible.clone()
452 } else {
453 (0..self.table.rows.len()).collect()
454 };
455
456 let values = if !args.is_empty()
458 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
459 {
460 let mut vals = Vec::new();
462 for &row_idx in &rows_to_process {
463 let value = self.evaluate(&args[0], row_idx)?;
464 vals.push(value);
465 }
466 Some(vals)
467 } else {
468 None
469 };
470
471 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
473 let mut state = agg_func.init();
474
475 if let Some(values) = values {
476 for value in &values {
478 agg_func.accumulate(&mut state, value)?;
479 }
480 } else {
481 for _ in &rows_to_process {
483 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
484 }
485 }
486
487 return Ok(agg_func.finalize(state));
488 }
489
490 if self.function_registry.get(name).is_some() {
492 let mut evaluated_args = Vec::new();
494 for arg in args {
495 evaluated_args.push(self.evaluate(arg, row_index)?);
496 }
497
498 let func = self.function_registry.get(name).unwrap();
500 return func.evaluate(&evaluated_args);
501 }
502
503 match name_upper.as_str() {
507 "CONVERT" => {
508 if args.len() != 3 {
509 return Err(anyhow!(
510 "CONVERT requires exactly 3 arguments: value, from_unit, to_unit"
511 ));
512 }
513
514 let value = self.evaluate(&args[0], row_index)?;
516 let numeric_value = match value {
517 DataValue::Integer(n) => n as f64,
518 DataValue::Float(f) => f,
519 _ => return Err(anyhow!("CONVERT first argument must be numeric")),
520 };
521
522 let from_unit = match self.evaluate(&args[1], row_index)? {
524 DataValue::String(s) => s,
525 DataValue::InternedString(s) => s.to_string(),
526 _ => {
527 return Err(anyhow!(
528 "CONVERT second argument must be a string (from_unit)"
529 ))
530 }
531 };
532
533 let to_unit = match self.evaluate(&args[2], row_index)? {
534 DataValue::String(s) => s,
535 DataValue::InternedString(s) => s.to_string(),
536 _ => return Err(anyhow!("CONVERT third argument must be a string (to_unit)")),
537 };
538
539 match crate::data::unit_converter::convert_units(
541 numeric_value,
542 &from_unit,
543 &to_unit,
544 ) {
545 Ok(result) => Ok(DataValue::Float(result)),
546 Err(e) => Err(anyhow!("Unit conversion error: {}", e)),
547 }
548 }
549 _ => Err(anyhow!("Unknown function: {}", name)),
550 }
551 }
552
553 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
555 let key = format!("{:?}", spec);
557
558 if let Some(context) = self.window_contexts.get(&key) {
559 return Ok(Arc::clone(context));
560 }
561
562 let data_view = if let Some(ref visible_rows) = self.visible_rows {
564 let mut view = DataView::new(Arc::new(self.table.clone()));
566 view
569 } else {
570 DataView::new(Arc::new(self.table.clone()))
571 };
572
573 let context = WindowContext::new(
575 Arc::new(data_view),
576 spec.partition_by.clone(),
577 spec.order_by.clone(),
578 )?;
579
580 let context = Arc::new(context);
581 self.window_contexts.insert(key, Arc::clone(&context));
582 Ok(context)
583 }
584
585 fn evaluate_window_function(
587 &mut self,
588 name: &str,
589 args: &[SqlExpression],
590 spec: &WindowSpec,
591 row_index: usize,
592 ) -> Result<DataValue> {
593 let context = self.get_or_create_window_context(spec)?;
594 let name_upper = name.to_uppercase();
595
596 match name_upper.as_str() {
597 "LAG" => {
598 if args.is_empty() {
600 return Err(anyhow!("LAG requires at least 1 argument"));
601 }
602
603 let column = match &args[0] {
605 SqlExpression::Column(col) => col.clone(),
606 _ => return Err(anyhow!("LAG first argument must be a column")),
607 };
608
609 let offset = if args.len() > 1 {
611 match self.evaluate(&args[1], row_index)? {
612 DataValue::Integer(i) => i as i32,
613 _ => return Err(anyhow!("LAG offset must be an integer")),
614 }
615 } else {
616 1
617 };
618
619 Ok(context
621 .get_offset_value(row_index, -offset, &column)
622 .unwrap_or(DataValue::Null))
623 }
624 "LEAD" => {
625 if args.is_empty() {
627 return Err(anyhow!("LEAD requires at least 1 argument"));
628 }
629
630 let column = match &args[0] {
632 SqlExpression::Column(col) => col.clone(),
633 _ => return Err(anyhow!("LEAD first argument must be a column")),
634 };
635
636 let offset = if args.len() > 1 {
638 match self.evaluate(&args[1], row_index)? {
639 DataValue::Integer(i) => i as i32,
640 _ => return Err(anyhow!("LEAD offset must be an integer")),
641 }
642 } else {
643 1
644 };
645
646 Ok(context
648 .get_offset_value(row_index, offset, &column)
649 .unwrap_or(DataValue::Null))
650 }
651 "ROW_NUMBER" => {
652 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
654 }
655 "FIRST_VALUE" => {
656 if args.is_empty() {
658 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
659 }
660
661 let column = match &args[0] {
662 SqlExpression::Column(col) => col.clone(),
663 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
664 };
665
666 Ok(context
667 .get_first_value(row_index, &column)
668 .unwrap_or(DataValue::Null))
669 }
670 "LAST_VALUE" => {
671 if args.is_empty() {
673 return Err(anyhow!("LAST_VALUE requires 1 argument"));
674 }
675
676 let column = match &args[0] {
677 SqlExpression::Column(col) => col.clone(),
678 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
679 };
680
681 Ok(context
682 .get_last_value(row_index, &column)
683 .unwrap_or(DataValue::Null))
684 }
685 _ => Err(anyhow!("Unknown window function: {}", name)),
686 }
687 }
688
689 fn evaluate_method_call(
691 &mut self,
692 object: &str,
693 method: &str,
694 args: &[SqlExpression],
695 row_index: usize,
696 ) -> Result<DataValue> {
697 let col_index = self.table.get_column_index(object).ok_or_else(|| {
699 let suggestion = self.find_similar_column(object);
700 match suggestion {
701 Some(similar) => {
702 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
703 }
704 None => anyhow!("Column '{}' not found", object),
705 }
706 })?;
707
708 let cell_value = self.table.get_value(row_index, col_index).cloned();
709
710 self.evaluate_method_on_value(
711 &cell_value.unwrap_or(DataValue::Null),
712 method,
713 args,
714 row_index,
715 )
716 }
717
718 fn evaluate_method_on_value(
720 &mut self,
721 value: &DataValue,
722 method: &str,
723 args: &[SqlExpression],
724 row_index: usize,
725 ) -> Result<DataValue> {
726 let function_name = match method.to_lowercase().as_str() {
731 "trim" => "TRIM",
732 "trimstart" | "trimbegin" => "TRIMSTART",
733 "trimend" => "TRIMEND",
734 "length" | "len" => "LENGTH",
735 "contains" => "CONTAINS",
736 "startswith" => "STARTSWITH",
737 "endswith" => "ENDSWITH",
738 "indexof" => "INDEXOF",
739 _ => method, };
741
742 if self.function_registry.get(function_name).is_some() {
744 debug!(
745 "Proxying method '{}' through function registry as '{}'",
746 method, function_name
747 );
748
749 let mut func_args = vec![value.clone()];
751
752 for arg in args {
754 func_args.push(self.evaluate(arg, row_index)?);
755 }
756
757 let func = self.function_registry.get(function_name).unwrap();
759 return func.evaluate(&func_args);
760 }
761
762 match method.to_lowercase().as_str() {
764 "trim" | "trimstart" | "trimend" => {
765 if !args.is_empty() {
766 return Err(anyhow!("{} takes no arguments", method));
767 }
768
769 let str_val = match value {
771 DataValue::String(s) => s.clone(),
772 DataValue::InternedString(s) => s.to_string(),
773 DataValue::Integer(n) => n.to_string(),
774 DataValue::Float(f) => f.to_string(),
775 DataValue::Boolean(b) => b.to_string(),
776 DataValue::DateTime(dt) => dt.clone(),
777 DataValue::Null => return Ok(DataValue::Null),
778 };
779
780 let result = match method.to_lowercase().as_str() {
781 "trim" => str_val.trim().to_string(),
782 "trimstart" => str_val.trim_start().to_string(),
783 "trimend" => str_val.trim_end().to_string(),
784 _ => unreachable!(),
785 };
786
787 Ok(DataValue::String(result))
788 }
789 "length" => {
790 if !args.is_empty() {
791 return Err(anyhow!("Length takes no arguments"));
792 }
793
794 let len = match value {
796 DataValue::String(s) => s.len(),
797 DataValue::InternedString(s) => s.len(),
798 DataValue::Integer(n) => n.to_string().len(),
799 DataValue::Float(f) => f.to_string().len(),
800 DataValue::Boolean(b) => b.to_string().len(),
801 DataValue::DateTime(dt) => dt.len(),
802 DataValue::Null => return Ok(DataValue::Integer(0)),
803 };
804
805 Ok(DataValue::Integer(len as i64))
806 }
807 "indexof" => {
808 if args.len() != 1 {
809 return Err(anyhow!("IndexOf requires exactly 1 argument"));
810 }
811
812 let search_str = match self.evaluate(&args[0], row_index)? {
814 DataValue::String(s) => s,
815 DataValue::InternedString(s) => s.to_string(),
816 DataValue::Integer(n) => n.to_string(),
817 DataValue::Float(f) => f.to_string(),
818 _ => return Err(anyhow!("IndexOf argument must be a string")),
819 };
820
821 let str_val = match value {
823 DataValue::String(s) => s.clone(),
824 DataValue::InternedString(s) => s.to_string(),
825 DataValue::Integer(n) => n.to_string(),
826 DataValue::Float(f) => f.to_string(),
827 DataValue::Boolean(b) => b.to_string(),
828 DataValue::DateTime(dt) => dt.clone(),
829 DataValue::Null => return Ok(DataValue::Integer(-1)),
830 };
831
832 let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
833
834 Ok(DataValue::Integer(index))
835 }
836 "contains" => {
837 if args.len() != 1 {
838 return Err(anyhow!("Contains requires exactly 1 argument"));
839 }
840
841 let search_str = match self.evaluate(&args[0], row_index)? {
843 DataValue::String(s) => s,
844 DataValue::InternedString(s) => s.to_string(),
845 DataValue::Integer(n) => n.to_string(),
846 DataValue::Float(f) => f.to_string(),
847 _ => return Err(anyhow!("Contains argument must be a string")),
848 };
849
850 let str_val = match value {
852 DataValue::String(s) => s.clone(),
853 DataValue::InternedString(s) => s.to_string(),
854 DataValue::Integer(n) => n.to_string(),
855 DataValue::Float(f) => f.to_string(),
856 DataValue::Boolean(b) => b.to_string(),
857 DataValue::DateTime(dt) => dt.clone(),
858 DataValue::Null => return Ok(DataValue::Boolean(false)),
859 };
860
861 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
863 Ok(DataValue::Boolean(result))
864 }
865 "startswith" => {
866 if args.len() != 1 {
867 return Err(anyhow!("StartsWith requires exactly 1 argument"));
868 }
869
870 let prefix = match self.evaluate(&args[0], row_index)? {
872 DataValue::String(s) => s,
873 DataValue::InternedString(s) => s.to_string(),
874 DataValue::Integer(n) => n.to_string(),
875 DataValue::Float(f) => f.to_string(),
876 _ => return Err(anyhow!("StartsWith argument must be a string")),
877 };
878
879 let str_val = match value {
881 DataValue::String(s) => s.clone(),
882 DataValue::InternedString(s) => s.to_string(),
883 DataValue::Integer(n) => n.to_string(),
884 DataValue::Float(f) => f.to_string(),
885 DataValue::Boolean(b) => b.to_string(),
886 DataValue::DateTime(dt) => dt.clone(),
887 DataValue::Null => return Ok(DataValue::Boolean(false)),
888 };
889
890 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
892 Ok(DataValue::Boolean(result))
893 }
894 "endswith" => {
895 if args.len() != 1 {
896 return Err(anyhow!("EndsWith requires exactly 1 argument"));
897 }
898
899 let suffix = match self.evaluate(&args[0], row_index)? {
901 DataValue::String(s) => s,
902 DataValue::InternedString(s) => s.to_string(),
903 DataValue::Integer(n) => n.to_string(),
904 DataValue::Float(f) => f.to_string(),
905 _ => return Err(anyhow!("EndsWith argument must be a string")),
906 };
907
908 let str_val = match value {
910 DataValue::String(s) => s.clone(),
911 DataValue::InternedString(s) => s.to_string(),
912 DataValue::Integer(n) => n.to_string(),
913 DataValue::Float(f) => f.to_string(),
914 DataValue::Boolean(b) => b.to_string(),
915 DataValue::DateTime(dt) => dt.clone(),
916 DataValue::Null => return Ok(DataValue::Boolean(false)),
917 };
918
919 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
921 Ok(DataValue::Boolean(result))
922 }
923 _ => Err(anyhow!("Unsupported method: {}", method)),
924 }
925 }
926
927 fn evaluate_case_expression(
929 &mut self,
930 when_branches: &[crate::sql::recursive_parser::WhenBranch],
931 else_branch: &Option<Box<SqlExpression>>,
932 row_index: usize,
933 ) -> Result<DataValue> {
934 debug!(
935 "ArithmeticEvaluator: evaluating CASE expression for row {}",
936 row_index
937 );
938
939 for branch in when_branches {
941 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
943
944 if condition_result {
945 debug!("CASE: WHEN condition matched, evaluating result expression");
946 return self.evaluate(&branch.result, row_index);
947 }
948 }
949
950 if let Some(else_expr) = else_branch {
952 debug!("CASE: No WHEN matched, evaluating ELSE expression");
953 self.evaluate(else_expr, row_index)
954 } else {
955 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
956 Ok(DataValue::Null)
957 }
958 }
959
960 fn evaluate_condition_as_bool(
962 &mut self,
963 expr: &SqlExpression,
964 row_index: usize,
965 ) -> Result<bool> {
966 let value = self.evaluate(expr, row_index)?;
967
968 match value {
969 DataValue::Boolean(b) => Ok(b),
970 DataValue::Integer(i) => Ok(i != 0),
971 DataValue::Float(f) => Ok(f != 0.0),
972 DataValue::Null => Ok(false),
973 DataValue::String(s) => Ok(!s.is_empty()),
974 DataValue::InternedString(s) => Ok(!s.is_empty()),
975 _ => Ok(true), }
977 }
978}
979
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table with columns a = 10, b = 2.5, c = 4.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(only_row).unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column("a".to_string());
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer literals stay integers.
        let integer = SqlExpression::NumberLiteral("42".to_string());
        assert_eq!(
            evaluator.evaluate(&integer, 0).unwrap(),
            DataValue::Integer(42)
        );

        // Anything that is not an integer falls back to float parsing.
        let float = SqlExpression::NumberLiteral("3.14".to_string());
        assert_eq!(
            evaluator.evaluate(&float, 0).unwrap(),
            DataValue::Float(3.14)
        );
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let int_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        let mixed_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let exact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        // Inexact integer division is promoted to float.
        let inexact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .to_string()
            .contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let product_expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        assert_eq!(
            evaluator.evaluate(&product_expr, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}