1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12pub struct ArithmeticEvaluator<'a> {
15 table: &'a DataTable,
16 date_notation: String,
17 function_registry: Arc<FunctionRegistry>,
18 aggregate_registry: Arc<AggregateRegistry>,
19 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, }
22
23impl<'a> ArithmeticEvaluator<'a> {
24 #[must_use]
25 pub fn new(table: &'a DataTable) -> Self {
26 Self {
27 table,
28 date_notation: "us".to_string(),
29 function_registry: Arc::new(FunctionRegistry::new()),
30 aggregate_registry: Arc::new(AggregateRegistry::new()),
31 visible_rows: None,
32 window_contexts: HashMap::new(),
33 }
34 }
35
36 #[must_use]
37 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38 Self {
39 table,
40 date_notation,
41 function_registry: Arc::new(FunctionRegistry::new()),
42 aggregate_registry: Arc::new(AggregateRegistry::new()),
43 visible_rows: None,
44 window_contexts: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51 self.visible_rows = Some(rows);
52 self
53 }
54
55 #[must_use]
56 pub fn with_date_notation_and_registry(
57 table: &'a DataTable,
58 date_notation: String,
59 function_registry: Arc<FunctionRegistry>,
60 ) -> Self {
61 Self {
62 table,
63 date_notation,
64 function_registry,
65 aggregate_registry: Arc::new(AggregateRegistry::new()),
66 visible_rows: None,
67 window_contexts: HashMap::new(),
68 }
69 }
70
71 fn find_similar_column(&self, name: &str) -> Option<String> {
73 let columns = self.table.column_names();
74 let mut best_match: Option<(String, usize)> = None;
75
76 for col in columns {
77 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78 let max_distance = if name.len() > 10 { 3 } else { 2 };
81 if distance <= max_distance {
82 match &best_match {
83 None => best_match = Some((col, distance)),
84 Some((_, best_dist)) if distance < *best_dist => {
85 best_match = Some((col, distance));
86 }
87 _ => {}
88 }
89 }
90 }
91
92 best_match.map(|(name, _)| name)
93 }
94
95 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99 }
100
101 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103 debug!(
104 "ArithmeticEvaluator: evaluating {:?} for row {}",
105 expr, row_index
106 );
107
108 match expr {
109 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113 SqlExpression::Null => Ok(DataValue::Null),
114 SqlExpression::BinaryOp { left, op, right } => {
115 self.evaluate_binary_op(left, op, right, row_index)
116 }
117 SqlExpression::FunctionCall { name, args } => {
118 self.evaluate_function(name, args, row_index)
119 }
120 SqlExpression::WindowFunction {
121 name,
122 args,
123 window_spec,
124 } => self.evaluate_window_function(name, args, window_spec, row_index),
125 SqlExpression::MethodCall {
126 object,
127 method,
128 args,
129 } => self.evaluate_method_call(object, method, args, row_index),
130 SqlExpression::ChainedMethodCall { base, method, args } => {
131 let base_value = self.evaluate(base, row_index)?;
133 self.evaluate_method_on_value(&base_value, method, args, row_index)
134 }
135 SqlExpression::CaseExpression {
136 when_branches,
137 else_branch,
138 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
139 _ => Err(anyhow!(
140 "Unsupported expression type for arithmetic evaluation: {:?}",
141 expr
142 )),
143 }
144 }
145
146 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
148 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
149 let suggestion = self.find_similar_column(column_name);
150 match suggestion {
151 Some(similar) => anyhow!(
152 "Column '{}' not found. Did you mean '{}'?",
153 column_name,
154 similar
155 ),
156 None => anyhow!("Column '{}' not found", column_name),
157 }
158 })?;
159
160 if row_index >= self.table.row_count() {
161 return Err(anyhow!("Row index {} out of bounds", row_index));
162 }
163
164 let row = self
165 .table
166 .get_row(row_index)
167 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
168
169 let value = row
170 .get(col_index)
171 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
172
173 Ok(value.clone())
174 }
175
176 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
178 if let Ok(int_val) = number_str.parse::<i64>() {
180 return Ok(DataValue::Integer(int_val));
181 }
182
183 if let Ok(float_val) = number_str.parse::<f64>() {
185 return Ok(DataValue::Float(float_val));
186 }
187
188 Err(anyhow!("Invalid number literal: {}", number_str))
189 }
190
191 fn evaluate_binary_op(
193 &mut self,
194 left: &SqlExpression,
195 op: &str,
196 right: &SqlExpression,
197 row_index: usize,
198 ) -> Result<DataValue> {
199 let left_val = self.evaluate(left, row_index)?;
200 let right_val = self.evaluate(right, row_index)?;
201
202 debug!(
203 "ArithmeticEvaluator: {} {} {}",
204 self.format_value(&left_val),
205 op,
206 self.format_value(&right_val)
207 );
208
209 match op {
210 "+" => self.add_values(&left_val, &right_val),
211 "-" => self.subtract_values(&left_val, &right_val),
212 "*" => self.multiply_values(&left_val, &right_val),
213 "/" => self.divide_values(&left_val, &right_val),
214 ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
216 "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
217 ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
218 "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
219 "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
220 "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
221 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
223 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
224 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
225 }
226 }
227
228 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
230 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
232 return Ok(DataValue::Null);
233 }
234
235 match (left, right) {
236 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
237 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
238 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
239 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
240 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
241 }
242 }
243
244 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
246 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
248 return Ok(DataValue::Null);
249 }
250
251 match (left, right) {
252 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
253 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
254 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
255 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
256 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
257 }
258 }
259
260 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
262 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
264 return Ok(DataValue::Null);
265 }
266
267 match (left, right) {
268 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
269 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
270 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
271 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
272 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
273 }
274 }
275
276 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
278 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
280 return Ok(DataValue::Null);
281 }
282
283 let is_zero = match right {
285 DataValue::Integer(0) => true,
286 DataValue::Float(f) if *f == 0.0 => true, _ => false,
288 };
289
290 if is_zero {
291 return Err(anyhow!("Division by zero"));
292 }
293
294 match (left, right) {
295 (DataValue::Integer(a), DataValue::Integer(b)) => {
296 if a % b == 0 {
298 Ok(DataValue::Integer(a / b))
299 } else {
300 Ok(DataValue::Float(*a as f64 / *b as f64))
301 }
302 }
303 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
304 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
305 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
306 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
307 }
308 }
309
310 fn format_value(&self, value: &DataValue) -> String {
312 match value {
313 DataValue::Integer(i) => i.to_string(),
314 DataValue::Float(f) => f.to_string(),
315 DataValue::String(s) => format!("'{s}'"),
316 _ => format!("{value:?}"),
317 }
318 }
319
320 fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
322 where
323 F: Fn(f64, f64) -> bool,
324 {
325 debug!(
326 "ArithmeticEvaluator: comparing values {:?} and {:?}",
327 left, right
328 );
329
330 let result = match (left, right) {
331 (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
333 (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
334 (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
335 (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
336
337 (DataValue::String(a), DataValue::String(b)) => {
339 let a_num = a.parse::<f64>();
340 let b_num = b.parse::<f64>();
341 match (a_num, b_num) {
342 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
345 }
346 (DataValue::InternedString(a), DataValue::InternedString(b)) => {
347 let a_num = a.parse::<f64>();
348 let b_num = b.parse::<f64>();
349 match (a_num, b_num) {
350 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
353 }
354 (DataValue::String(a), DataValue::InternedString(b)) => {
355 let a_num = a.parse::<f64>();
356 let b_num = b.parse::<f64>();
357 match (a_num, b_num) {
358 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
361 }
362 (DataValue::InternedString(a), DataValue::String(b)) => {
363 let a_num = a.parse::<f64>();
364 let b_num = b.parse::<f64>();
365 match (a_num, b_num) {
366 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
369 }
370
371 (DataValue::String(a), DataValue::Integer(b)) => {
373 match a.parse::<f64>() {
374 Ok(a_val) => op(a_val, *b as f64),
375 Err(_) => false, }
377 }
378 (DataValue::Integer(a), DataValue::String(b)) => {
379 match b.parse::<f64>() {
380 Ok(b_val) => op(*a as f64, b_val),
381 Err(_) => false, }
383 }
384 (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
385 Ok(a_val) => op(a_val, *b),
386 Err(_) => false,
387 },
388 (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
389 Ok(b_val) => op(*a, b_val),
390 Err(_) => false,
391 },
392
393 (DataValue::Null, _) | (_, DataValue::Null) => false,
395
396 (DataValue::Boolean(a), DataValue::Boolean(b)) => {
398 op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
399 }
400
401 _ => {
402 debug!(
403 "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
404 left, right
405 );
406 false
407 }
408 };
409
410 debug!("ArithmeticEvaluator: comparison result: {}", result);
411 Ok(DataValue::Boolean(result))
412 }
413
414 fn evaluate_function(
416 &mut self,
417 name: &str,
418 args: &[SqlExpression],
419 row_index: usize,
420 ) -> Result<DataValue> {
421 let name_upper = name.to_uppercase();
423
424 if name_upper == "COUNT" && args.len() == 1 {
426 match &args[0] {
427 SqlExpression::Column(col) if col == "*" => {
428 let count = if let Some(ref visible) = self.visible_rows {
430 visible.len() as i64
431 } else {
432 self.table.rows.len() as i64
433 };
434 return Ok(DataValue::Integer(count));
435 }
436 SqlExpression::StringLiteral(s) if s == "*" => {
437 let count = if let Some(ref visible) = self.visible_rows {
439 visible.len() as i64
440 } else {
441 self.table.rows.len() as i64
442 };
443 return Ok(DataValue::Integer(count));
444 }
445 _ => {
446 }
448 }
449 }
450
451 if self.aggregate_registry.get(&name_upper).is_some() {
453 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
455 visible.clone()
456 } else {
457 (0..self.table.rows.len()).collect()
458 };
459
460 let values = if !args.is_empty()
462 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
463 {
464 let mut vals = Vec::new();
466 for &row_idx in &rows_to_process {
467 let value = self.evaluate(&args[0], row_idx)?;
468 vals.push(value);
469 }
470 Some(vals)
471 } else {
472 None
473 };
474
475 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
477 let mut state = agg_func.init();
478
479 if let Some(values) = values {
480 for value in &values {
482 agg_func.accumulate(&mut state, value)?;
483 }
484 } else {
485 for _ in &rows_to_process {
487 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
488 }
489 }
490
491 return Ok(agg_func.finalize(state));
492 }
493
494 if self.function_registry.get(name).is_some() {
496 let mut evaluated_args = Vec::new();
498 for arg in args {
499 evaluated_args.push(self.evaluate(arg, row_index)?);
500 }
501
502 let func = self.function_registry.get(name).unwrap();
504 return func.evaluate(&evaluated_args);
505 }
506
507 Err(anyhow!("Unknown function: {}", name))
509 }
510
511 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
513 let key = format!("{:?}", spec);
515
516 if let Some(context) = self.window_contexts.get(&key) {
517 return Ok(Arc::clone(context));
518 }
519
520 let data_view = if let Some(ref visible_rows) = self.visible_rows {
522 let mut view = DataView::new(Arc::new(self.table.clone()));
524 view
527 } else {
528 DataView::new(Arc::new(self.table.clone()))
529 };
530
531 let context = WindowContext::new(
533 Arc::new(data_view),
534 spec.partition_by.clone(),
535 spec.order_by.clone(),
536 )?;
537
538 let context = Arc::new(context);
539 self.window_contexts.insert(key, Arc::clone(&context));
540 Ok(context)
541 }
542
543 fn evaluate_window_function(
545 &mut self,
546 name: &str,
547 args: &[SqlExpression],
548 spec: &WindowSpec,
549 row_index: usize,
550 ) -> Result<DataValue> {
551 let context = self.get_or_create_window_context(spec)?;
552 let name_upper = name.to_uppercase();
553
554 match name_upper.as_str() {
555 "LAG" => {
556 if args.is_empty() {
558 return Err(anyhow!("LAG requires at least 1 argument"));
559 }
560
561 let column = match &args[0] {
563 SqlExpression::Column(col) => col.clone(),
564 _ => return Err(anyhow!("LAG first argument must be a column")),
565 };
566
567 let offset = if args.len() > 1 {
569 match self.evaluate(&args[1], row_index)? {
570 DataValue::Integer(i) => i as i32,
571 _ => return Err(anyhow!("LAG offset must be an integer")),
572 }
573 } else {
574 1
575 };
576
577 Ok(context
579 .get_offset_value(row_index, -offset, &column)
580 .unwrap_or(DataValue::Null))
581 }
582 "LEAD" => {
583 if args.is_empty() {
585 return Err(anyhow!("LEAD requires at least 1 argument"));
586 }
587
588 let column = match &args[0] {
590 SqlExpression::Column(col) => col.clone(),
591 _ => return Err(anyhow!("LEAD first argument must be a column")),
592 };
593
594 let offset = if args.len() > 1 {
596 match self.evaluate(&args[1], row_index)? {
597 DataValue::Integer(i) => i as i32,
598 _ => return Err(anyhow!("LEAD offset must be an integer")),
599 }
600 } else {
601 1
602 };
603
604 Ok(context
606 .get_offset_value(row_index, offset, &column)
607 .unwrap_or(DataValue::Null))
608 }
609 "ROW_NUMBER" => {
610 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
612 }
613 "FIRST_VALUE" => {
614 if args.is_empty() {
616 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
617 }
618
619 let column = match &args[0] {
620 SqlExpression::Column(col) => col.clone(),
621 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
622 };
623
624 Ok(context
625 .get_first_value(row_index, &column)
626 .unwrap_or(DataValue::Null))
627 }
628 "LAST_VALUE" => {
629 if args.is_empty() {
631 return Err(anyhow!("LAST_VALUE requires 1 argument"));
632 }
633
634 let column = match &args[0] {
635 SqlExpression::Column(col) => col.clone(),
636 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
637 };
638
639 Ok(context
640 .get_last_value(row_index, &column)
641 .unwrap_or(DataValue::Null))
642 }
643 _ => Err(anyhow!("Unknown window function: {}", name)),
644 }
645 }
646
647 fn evaluate_method_call(
649 &mut self,
650 object: &str,
651 method: &str,
652 args: &[SqlExpression],
653 row_index: usize,
654 ) -> Result<DataValue> {
655 let col_index = self.table.get_column_index(object).ok_or_else(|| {
657 let suggestion = self.find_similar_column(object);
658 match suggestion {
659 Some(similar) => {
660 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
661 }
662 None => anyhow!("Column '{}' not found", object),
663 }
664 })?;
665
666 let cell_value = self.table.get_value(row_index, col_index).cloned();
667
668 self.evaluate_method_on_value(
669 &cell_value.unwrap_or(DataValue::Null),
670 method,
671 args,
672 row_index,
673 )
674 }
675
676 fn evaluate_method_on_value(
678 &mut self,
679 value: &DataValue,
680 method: &str,
681 args: &[SqlExpression],
682 row_index: usize,
683 ) -> Result<DataValue> {
684 let function_name = match method.to_lowercase().as_str() {
689 "trim" => "TRIM",
690 "trimstart" | "trimbegin" => "TRIMSTART",
691 "trimend" => "TRIMEND",
692 "length" | "len" => "LENGTH",
693 "contains" => "CONTAINS",
694 "startswith" => "STARTSWITH",
695 "endswith" => "ENDSWITH",
696 "indexof" => "INDEXOF",
697 _ => method, };
699
700 if self.function_registry.get(function_name).is_some() {
702 debug!(
703 "Proxying method '{}' through function registry as '{}'",
704 method, function_name
705 );
706
707 let mut func_args = vec![value.clone()];
709
710 for arg in args {
712 func_args.push(self.evaluate(arg, row_index)?);
713 }
714
715 let func = self.function_registry.get(function_name).unwrap();
717 return func.evaluate(&func_args);
718 }
719
720 match method.to_lowercase().as_str() {
722 "trim" | "trimstart" | "trimend" => {
723 if !args.is_empty() {
724 return Err(anyhow!("{} takes no arguments", method));
725 }
726
727 let str_val = match value {
729 DataValue::String(s) => s.clone(),
730 DataValue::InternedString(s) => s.to_string(),
731 DataValue::Integer(n) => n.to_string(),
732 DataValue::Float(f) => f.to_string(),
733 DataValue::Boolean(b) => b.to_string(),
734 DataValue::DateTime(dt) => dt.clone(),
735 DataValue::Null => return Ok(DataValue::Null),
736 };
737
738 let result = match method.to_lowercase().as_str() {
739 "trim" => str_val.trim().to_string(),
740 "trimstart" => str_val.trim_start().to_string(),
741 "trimend" => str_val.trim_end().to_string(),
742 _ => unreachable!(),
743 };
744
745 Ok(DataValue::String(result))
746 }
747 "length" => {
748 if !args.is_empty() {
749 return Err(anyhow!("Length takes no arguments"));
750 }
751
752 let len = match value {
754 DataValue::String(s) => s.len(),
755 DataValue::InternedString(s) => s.len(),
756 DataValue::Integer(n) => n.to_string().len(),
757 DataValue::Float(f) => f.to_string().len(),
758 DataValue::Boolean(b) => b.to_string().len(),
759 DataValue::DateTime(dt) => dt.len(),
760 DataValue::Null => return Ok(DataValue::Integer(0)),
761 };
762
763 Ok(DataValue::Integer(len as i64))
764 }
765 "indexof" => {
766 if args.len() != 1 {
767 return Err(anyhow!("IndexOf requires exactly 1 argument"));
768 }
769
770 let search_str = match self.evaluate(&args[0], row_index)? {
772 DataValue::String(s) => s,
773 DataValue::InternedString(s) => s.to_string(),
774 DataValue::Integer(n) => n.to_string(),
775 DataValue::Float(f) => f.to_string(),
776 _ => return Err(anyhow!("IndexOf argument must be a string")),
777 };
778
779 let str_val = match value {
781 DataValue::String(s) => s.clone(),
782 DataValue::InternedString(s) => s.to_string(),
783 DataValue::Integer(n) => n.to_string(),
784 DataValue::Float(f) => f.to_string(),
785 DataValue::Boolean(b) => b.to_string(),
786 DataValue::DateTime(dt) => dt.clone(),
787 DataValue::Null => return Ok(DataValue::Integer(-1)),
788 };
789
790 let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
791
792 Ok(DataValue::Integer(index))
793 }
794 "contains" => {
795 if args.len() != 1 {
796 return Err(anyhow!("Contains requires exactly 1 argument"));
797 }
798
799 let search_str = match self.evaluate(&args[0], row_index)? {
801 DataValue::String(s) => s,
802 DataValue::InternedString(s) => s.to_string(),
803 DataValue::Integer(n) => n.to_string(),
804 DataValue::Float(f) => f.to_string(),
805 _ => return Err(anyhow!("Contains argument must be a string")),
806 };
807
808 let str_val = match value {
810 DataValue::String(s) => s.clone(),
811 DataValue::InternedString(s) => s.to_string(),
812 DataValue::Integer(n) => n.to_string(),
813 DataValue::Float(f) => f.to_string(),
814 DataValue::Boolean(b) => b.to_string(),
815 DataValue::DateTime(dt) => dt.clone(),
816 DataValue::Null => return Ok(DataValue::Boolean(false)),
817 };
818
819 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
821 Ok(DataValue::Boolean(result))
822 }
823 "startswith" => {
824 if args.len() != 1 {
825 return Err(anyhow!("StartsWith requires exactly 1 argument"));
826 }
827
828 let prefix = match self.evaluate(&args[0], row_index)? {
830 DataValue::String(s) => s,
831 DataValue::InternedString(s) => s.to_string(),
832 DataValue::Integer(n) => n.to_string(),
833 DataValue::Float(f) => f.to_string(),
834 _ => return Err(anyhow!("StartsWith argument must be a string")),
835 };
836
837 let str_val = match value {
839 DataValue::String(s) => s.clone(),
840 DataValue::InternedString(s) => s.to_string(),
841 DataValue::Integer(n) => n.to_string(),
842 DataValue::Float(f) => f.to_string(),
843 DataValue::Boolean(b) => b.to_string(),
844 DataValue::DateTime(dt) => dt.clone(),
845 DataValue::Null => return Ok(DataValue::Boolean(false)),
846 };
847
848 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
850 Ok(DataValue::Boolean(result))
851 }
852 "endswith" => {
853 if args.len() != 1 {
854 return Err(anyhow!("EndsWith requires exactly 1 argument"));
855 }
856
857 let suffix = match self.evaluate(&args[0], row_index)? {
859 DataValue::String(s) => s,
860 DataValue::InternedString(s) => s.to_string(),
861 DataValue::Integer(n) => n.to_string(),
862 DataValue::Float(f) => f.to_string(),
863 _ => return Err(anyhow!("EndsWith argument must be a string")),
864 };
865
866 let str_val = match value {
868 DataValue::String(s) => s.clone(),
869 DataValue::InternedString(s) => s.to_string(),
870 DataValue::Integer(n) => n.to_string(),
871 DataValue::Float(f) => f.to_string(),
872 DataValue::Boolean(b) => b.to_string(),
873 DataValue::DateTime(dt) => dt.clone(),
874 DataValue::Null => return Ok(DataValue::Boolean(false)),
875 };
876
877 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
879 Ok(DataValue::Boolean(result))
880 }
881 _ => Err(anyhow!("Unsupported method: {}", method)),
882 }
883 }
884
885 fn evaluate_case_expression(
887 &mut self,
888 when_branches: &[crate::sql::recursive_parser::WhenBranch],
889 else_branch: &Option<Box<SqlExpression>>,
890 row_index: usize,
891 ) -> Result<DataValue> {
892 debug!(
893 "ArithmeticEvaluator: evaluating CASE expression for row {}",
894 row_index
895 );
896
897 for branch in when_branches {
899 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
901
902 if condition_result {
903 debug!("CASE: WHEN condition matched, evaluating result expression");
904 return self.evaluate(&branch.result, row_index);
905 }
906 }
907
908 if let Some(else_expr) = else_branch {
910 debug!("CASE: No WHEN matched, evaluating ELSE expression");
911 self.evaluate(else_expr, row_index)
912 } else {
913 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
914 Ok(DataValue::Null)
915 }
916 }
917
918 fn evaluate_condition_as_bool(
920 &mut self,
921 expr: &SqlExpression,
922 row_index: usize,
923 ) -> Result<bool> {
924 let value = self.evaluate(expr, row_index)?;
925
926 match value {
927 DataValue::Boolean(b) => Ok(b),
928 DataValue::Integer(i) => Ok(i != 0),
929 DataValue::Float(f) => Ok(f != 0.0),
930 DataValue::Null => Ok(false),
931 DataValue::String(s) => Ok(!s.is_empty()),
932 DataValue::InternedString(s) => Ok(!s.is_empty()),
933 _ => Ok(true), }
935 }
936}
937
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_unknown_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("zzzzzz".to_string());
        let error = evaluator.evaluate(&expr, 0).unwrap_err();
        assert!(error.to_string().contains("not found"));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is exactly representable, so the equality check is exact.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_subtract_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .subtract_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(2));

        let result = evaluator
            .subtract_values(&DataValue::Float(5.5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(2.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Inexact integer division is promoted to a float.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));

        let result = evaluator.divide_values(&DataValue::Float(1.0), &DataValue::Float(0.0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_null_propagation() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Null, &DataValue::Integer(1))
            .unwrap();
        assert_eq!(result, DataValue::Null);

        let result = evaluator
            .multiply_values(&DataValue::Integer(2), &DataValue::Null)
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_comparison_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a > c  =>  10 > 4
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: ">".to_string(),
            right: Box::new(SqlExpression::Column("c".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Boolean(true));
    }
}