1use crate::data::data_view::DataView;
2use crate::data::datatable::{DataTable, DataValue};
3use crate::sql::aggregates::AggregateRegistry;
4use crate::sql::functions::FunctionRegistry;
5use crate::sql::recursive_parser::{SqlExpression, WindowSpec};
6use crate::sql::window_context::WindowContext;
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12pub struct ArithmeticEvaluator<'a> {
15 table: &'a DataTable,
16 date_notation: String,
17 function_registry: Arc<FunctionRegistry>,
18 aggregate_registry: Arc<AggregateRegistry>,
19 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, }
22
23impl<'a> ArithmeticEvaluator<'a> {
24 #[must_use]
25 pub fn new(table: &'a DataTable) -> Self {
26 Self {
27 table,
28 date_notation: "us".to_string(),
29 function_registry: Arc::new(FunctionRegistry::new()),
30 aggregate_registry: Arc::new(AggregateRegistry::new()),
31 visible_rows: None,
32 window_contexts: HashMap::new(),
33 }
34 }
35
36 #[must_use]
37 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
38 Self {
39 table,
40 date_notation,
41 function_registry: Arc::new(FunctionRegistry::new()),
42 aggregate_registry: Arc::new(AggregateRegistry::new()),
43 visible_rows: None,
44 window_contexts: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
51 self.visible_rows = Some(rows);
52 self
53 }
54
55 #[must_use]
56 pub fn with_date_notation_and_registry(
57 table: &'a DataTable,
58 date_notation: String,
59 function_registry: Arc<FunctionRegistry>,
60 ) -> Self {
61 Self {
62 table,
63 date_notation,
64 function_registry,
65 aggregate_registry: Arc::new(AggregateRegistry::new()),
66 visible_rows: None,
67 window_contexts: HashMap::new(),
68 }
69 }
70
71 fn find_similar_column(&self, name: &str) -> Option<String> {
73 let columns = self.table.column_names();
74 let mut best_match: Option<(String, usize)> = None;
75
76 for col in columns {
77 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
78 let max_distance = if name.len() > 10 { 3 } else { 2 };
81 if distance <= max_distance {
82 match &best_match {
83 None => best_match = Some((col, distance)),
84 Some((_, best_dist)) if distance < *best_dist => {
85 best_match = Some((col, distance));
86 }
87 _ => {}
88 }
89 }
90 }
91
92 best_match.map(|(name, _)| name)
93 }
94
95 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
97 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
99 }
100
101 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
103 debug!(
104 "ArithmeticEvaluator: evaluating {:?} for row {}",
105 expr, row_index
106 );
107
108 match expr {
109 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
110 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
111 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
112 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
113 SqlExpression::Null => Ok(DataValue::Null),
114 SqlExpression::BinaryOp { left, op, right } => {
115 self.evaluate_binary_op(left, op, right, row_index)
116 }
117 SqlExpression::FunctionCall {
118 name,
119 args,
120 distinct,
121 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
122 SqlExpression::WindowFunction {
123 name,
124 args,
125 window_spec,
126 } => self.evaluate_window_function(name, args, window_spec, row_index),
127 SqlExpression::MethodCall {
128 object,
129 method,
130 args,
131 } => self.evaluate_method_call(object, method, args, row_index),
132 SqlExpression::ChainedMethodCall { base, method, args } => {
133 let base_value = self.evaluate(base, row_index)?;
135 self.evaluate_method_on_value(&base_value, method, args, row_index)
136 }
137 SqlExpression::CaseExpression {
138 when_branches,
139 else_branch,
140 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
141 _ => Err(anyhow!(
142 "Unsupported expression type for arithmetic evaluation: {:?}",
143 expr
144 )),
145 }
146 }
147
148 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
150 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
151 let suggestion = self.find_similar_column(column_name);
152 match suggestion {
153 Some(similar) => anyhow!(
154 "Column '{}' not found. Did you mean '{}'?",
155 column_name,
156 similar
157 ),
158 None => anyhow!("Column '{}' not found", column_name),
159 }
160 })?;
161
162 if row_index >= self.table.row_count() {
163 return Err(anyhow!("Row index {} out of bounds", row_index));
164 }
165
166 let row = self
167 .table
168 .get_row(row_index)
169 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
170
171 let value = row
172 .get(col_index)
173 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
174
175 Ok(value.clone())
176 }
177
178 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
180 if let Ok(int_val) = number_str.parse::<i64>() {
182 return Ok(DataValue::Integer(int_val));
183 }
184
185 if let Ok(float_val) = number_str.parse::<f64>() {
187 return Ok(DataValue::Float(float_val));
188 }
189
190 Err(anyhow!("Invalid number literal: {}", number_str))
191 }
192
193 fn evaluate_binary_op(
195 &mut self,
196 left: &SqlExpression,
197 op: &str,
198 right: &SqlExpression,
199 row_index: usize,
200 ) -> Result<DataValue> {
201 let left_val = self.evaluate(left, row_index)?;
202 let right_val = self.evaluate(right, row_index)?;
203
204 debug!(
205 "ArithmeticEvaluator: {} {} {}",
206 self.format_value(&left_val),
207 op,
208 self.format_value(&right_val)
209 );
210
211 match op {
212 "+" => self.add_values(&left_val, &right_val),
213 "-" => self.subtract_values(&left_val, &right_val),
214 "*" => self.multiply_values(&left_val, &right_val),
215 "/" => self.divide_values(&left_val, &right_val),
216 "%" => {
217 let args = vec![left.clone(), right.clone()];
219 self.evaluate_function("MOD", &args, row_index)
220 }
221 ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
223 "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
224 ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
225 "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
226 "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
227 "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
228 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
230 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
231 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
232 }
233 }
234
235 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
237 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
239 return Ok(DataValue::Null);
240 }
241
242 match (left, right) {
243 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
244 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
245 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
246 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
247 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
248 }
249 }
250
251 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
253 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
255 return Ok(DataValue::Null);
256 }
257
258 match (left, right) {
259 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
260 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
261 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
262 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
263 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
264 }
265 }
266
267 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
269 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
271 return Ok(DataValue::Null);
272 }
273
274 match (left, right) {
275 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
276 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
277 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
278 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
279 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
280 }
281 }
282
283 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
285 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
287 return Ok(DataValue::Null);
288 }
289
290 let is_zero = match right {
292 DataValue::Integer(0) => true,
293 DataValue::Float(f) if *f == 0.0 => true, _ => false,
295 };
296
297 if is_zero {
298 return Err(anyhow!("Division by zero"));
299 }
300
301 match (left, right) {
302 (DataValue::Integer(a), DataValue::Integer(b)) => {
303 if a % b == 0 {
305 Ok(DataValue::Integer(a / b))
306 } else {
307 Ok(DataValue::Float(*a as f64 / *b as f64))
308 }
309 }
310 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
311 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
312 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
313 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
314 }
315 }
316
317 fn format_value(&self, value: &DataValue) -> String {
319 match value {
320 DataValue::Integer(i) => i.to_string(),
321 DataValue::Float(f) => f.to_string(),
322 DataValue::String(s) => format!("'{s}'"),
323 _ => format!("{value:?}"),
324 }
325 }
326
327 fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
329 where
330 F: Fn(f64, f64) -> bool,
331 {
332 debug!(
333 "ArithmeticEvaluator: comparing values {:?} and {:?}",
334 left, right
335 );
336
337 let result = match (left, right) {
338 (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
340 (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
341 (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
342 (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
343
344 (DataValue::String(a), DataValue::String(b)) => {
346 let a_num = a.parse::<f64>();
347 let b_num = b.parse::<f64>();
348 match (a_num, b_num) {
349 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
352 }
353 (DataValue::InternedString(a), DataValue::InternedString(b)) => {
354 let a_num = a.parse::<f64>();
355 let b_num = b.parse::<f64>();
356 match (a_num, b_num) {
357 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
360 }
361 (DataValue::String(a), DataValue::InternedString(b)) => {
362 let a_num = a.parse::<f64>();
363 let b_num = b.parse::<f64>();
364 match (a_num, b_num) {
365 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
368 }
369 (DataValue::InternedString(a), DataValue::String(b)) => {
370 let a_num = a.parse::<f64>();
371 let b_num = b.parse::<f64>();
372 match (a_num, b_num) {
373 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
376 }
377
378 (DataValue::String(a), DataValue::Integer(b)) => {
380 match a.parse::<f64>() {
381 Ok(a_val) => op(a_val, *b as f64),
382 Err(_) => false, }
384 }
385 (DataValue::Integer(a), DataValue::String(b)) => {
386 match b.parse::<f64>() {
387 Ok(b_val) => op(*a as f64, b_val),
388 Err(_) => false, }
390 }
391 (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
392 Ok(a_val) => op(a_val, *b),
393 Err(_) => false,
394 },
395 (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
396 Ok(b_val) => op(*a, b_val),
397 Err(_) => false,
398 },
399
400 (DataValue::Null, _) | (_, DataValue::Null) => false,
402
403 (DataValue::Boolean(a), DataValue::Boolean(b)) => {
405 op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
406 }
407
408 _ => {
409 debug!(
410 "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
411 left, right
412 );
413 false
414 }
415 };
416
417 debug!("ArithmeticEvaluator: comparison result: {}", result);
418 Ok(DataValue::Boolean(result))
419 }
420
421 fn evaluate_function_with_distinct(
423 &mut self,
424 name: &str,
425 args: &[SqlExpression],
426 distinct: bool,
427 row_index: usize,
428 ) -> Result<DataValue> {
429 if distinct {
431 let name_upper = name.to_uppercase();
432
433 if name_upper == "COUNT"
435 || name_upper == "SUM"
436 || name_upper == "AVG"
437 || name_upper == "MIN"
438 || name_upper == "MAX"
439 {
440 return self.evaluate_aggregate_distinct(&name_upper, args, row_index);
441 } else {
442 return Err(anyhow!(
443 "DISTINCT can only be used with aggregate functions"
444 ));
445 }
446 }
447
448 self.evaluate_function(name, args, row_index)
450 }
451
452 fn evaluate_aggregate_distinct(
453 &mut self,
454 name: &str,
455 args: &[SqlExpression],
456 row_index: usize,
457 ) -> Result<DataValue> {
458 use std::collections::HashSet;
459
460 if args.is_empty() {
461 return Err(anyhow!("{} DISTINCT requires at least one argument", name));
462 }
463
464 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
466 visible.clone()
467 } else {
468 (0..self.table.rows.len()).collect()
469 };
470
471 let mut unique_values = HashSet::new();
473 let mut numeric_values = Vec::new();
474
475 for row_idx in &rows_to_process {
476 let value = self.evaluate(&args[0], *row_idx)?;
478
479 if matches!(value, DataValue::Null) {
481 continue;
482 }
483
484 let value_str = match &value {
486 DataValue::String(s) => s.clone(),
487 DataValue::InternedString(s) => s.to_string(),
488 DataValue::Integer(i) => i.to_string(),
489 DataValue::Float(f) => f.to_string(),
490 DataValue::Boolean(b) => b.to_string(),
491 DataValue::DateTime(dt) => dt.to_string(),
492 DataValue::Null => continue,
493 };
494
495 if unique_values.insert(value_str) {
497 if name != "COUNT" {
499 match value {
500 DataValue::Integer(i) => numeric_values.push(i as f64),
501 DataValue::Float(f) => numeric_values.push(f),
502 _ => {} }
504 }
505 }
506 }
507
508 match name {
510 "COUNT" => Ok(DataValue::Integer(unique_values.len() as i64)),
511 "SUM" => {
512 if numeric_values.is_empty() {
513 Ok(DataValue::Null)
514 } else {
515 let sum: f64 = numeric_values.iter().sum();
516 if sum.fract() == 0.0 && sum.abs() < 1e10 {
517 Ok(DataValue::Integer(sum as i64))
518 } else {
519 Ok(DataValue::Float(sum))
520 }
521 }
522 }
523 "AVG" => {
524 if numeric_values.is_empty() {
525 Ok(DataValue::Null)
526 } else {
527 let sum: f64 = numeric_values.iter().sum();
528 Ok(DataValue::Float(sum / numeric_values.len() as f64))
529 }
530 }
531 "MIN" => {
532 if numeric_values.is_empty() {
533 Ok(DataValue::Null)
534 } else {
535 let min = numeric_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
536 if min.fract() == 0.0 && min.abs() < 1e10 {
537 Ok(DataValue::Integer(min as i64))
538 } else {
539 Ok(DataValue::Float(min))
540 }
541 }
542 }
543 "MAX" => {
544 if numeric_values.is_empty() {
545 Ok(DataValue::Null)
546 } else {
547 let max = numeric_values
548 .iter()
549 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
550 if max.fract() == 0.0 && max.abs() < 1e10 {
551 Ok(DataValue::Integer(max as i64))
552 } else {
553 Ok(DataValue::Float(max))
554 }
555 }
556 }
557 _ => Err(anyhow!("Unsupported DISTINCT aggregate: {}", name)),
558 }
559 }
560
561 fn evaluate_function(
562 &mut self,
563 name: &str,
564 args: &[SqlExpression],
565 row_index: usize,
566 ) -> Result<DataValue> {
567 let name_upper = name.to_uppercase();
569
570 if name_upper == "COUNT" && args.len() == 1 {
572 match &args[0] {
573 SqlExpression::Column(col) if col == "*" => {
574 let count = if let Some(ref visible) = self.visible_rows {
576 visible.len() as i64
577 } else {
578 self.table.rows.len() as i64
579 };
580 return Ok(DataValue::Integer(count));
581 }
582 SqlExpression::StringLiteral(s) if s == "*" => {
583 let count = if let Some(ref visible) = self.visible_rows {
585 visible.len() as i64
586 } else {
587 self.table.rows.len() as i64
588 };
589 return Ok(DataValue::Integer(count));
590 }
591 _ => {
592 }
594 }
595 }
596
597 if self.aggregate_registry.get(&name_upper).is_some() {
599 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
601 visible.clone()
602 } else {
603 (0..self.table.rows.len()).collect()
604 };
605
606 let values = if !args.is_empty()
608 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
609 {
610 let mut vals = Vec::new();
612 for &row_idx in &rows_to_process {
613 let value = self.evaluate(&args[0], row_idx)?;
614 vals.push(value);
615 }
616 Some(vals)
617 } else {
618 None
619 };
620
621 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
623 let mut state = agg_func.init();
624
625 if let Some(values) = values {
626 for value in &values {
628 agg_func.accumulate(&mut state, value)?;
629 }
630 } else {
631 for _ in &rows_to_process {
633 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
634 }
635 }
636
637 return Ok(agg_func.finalize(state));
638 }
639
640 if self.function_registry.get(name).is_some() {
642 let mut evaluated_args = Vec::new();
644 for arg in args {
645 evaluated_args.push(self.evaluate(arg, row_index)?);
646 }
647
648 let func = self.function_registry.get(name).unwrap();
650 return func.evaluate(&evaluated_args);
651 }
652
653 Err(anyhow!("Unknown function: {}", name))
655 }
656
657 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
659 let key = format!("{:?}", spec);
661
662 if let Some(context) = self.window_contexts.get(&key) {
663 return Ok(Arc::clone(context));
664 }
665
666 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
668 let view = DataView::new(Arc::new(self.table.clone()));
670 view
673 } else {
674 DataView::new(Arc::new(self.table.clone()))
675 };
676
677 let context = WindowContext::new(
679 Arc::new(data_view),
680 spec.partition_by.clone(),
681 spec.order_by.clone(),
682 )?;
683
684 let context = Arc::new(context);
685 self.window_contexts.insert(key, Arc::clone(&context));
686 Ok(context)
687 }
688
689 fn evaluate_window_function(
691 &mut self,
692 name: &str,
693 args: &[SqlExpression],
694 spec: &WindowSpec,
695 row_index: usize,
696 ) -> Result<DataValue> {
697 let context = self.get_or_create_window_context(spec)?;
698 let name_upper = name.to_uppercase();
699
700 match name_upper.as_str() {
701 "LAG" => {
702 if args.is_empty() {
704 return Err(anyhow!("LAG requires at least 1 argument"));
705 }
706
707 let column = match &args[0] {
709 SqlExpression::Column(col) => col.clone(),
710 _ => return Err(anyhow!("LAG first argument must be a column")),
711 };
712
713 let offset = if args.len() > 1 {
715 match self.evaluate(&args[1], row_index)? {
716 DataValue::Integer(i) => i as i32,
717 _ => return Err(anyhow!("LAG offset must be an integer")),
718 }
719 } else {
720 1
721 };
722
723 Ok(context
725 .get_offset_value(row_index, -offset, &column)
726 .unwrap_or(DataValue::Null))
727 }
728 "LEAD" => {
729 if args.is_empty() {
731 return Err(anyhow!("LEAD requires at least 1 argument"));
732 }
733
734 let column = match &args[0] {
736 SqlExpression::Column(col) => col.clone(),
737 _ => return Err(anyhow!("LEAD first argument must be a column")),
738 };
739
740 let offset = if args.len() > 1 {
742 match self.evaluate(&args[1], row_index)? {
743 DataValue::Integer(i) => i as i32,
744 _ => return Err(anyhow!("LEAD offset must be an integer")),
745 }
746 } else {
747 1
748 };
749
750 Ok(context
752 .get_offset_value(row_index, offset, &column)
753 .unwrap_or(DataValue::Null))
754 }
755 "ROW_NUMBER" => {
756 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
758 }
759 "FIRST_VALUE" => {
760 if args.is_empty() {
762 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
763 }
764
765 let column = match &args[0] {
766 SqlExpression::Column(col) => col.clone(),
767 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
768 };
769
770 Ok(context
771 .get_first_value(row_index, &column)
772 .unwrap_or(DataValue::Null))
773 }
774 "LAST_VALUE" => {
775 if args.is_empty() {
777 return Err(anyhow!("LAST_VALUE requires 1 argument"));
778 }
779
780 let column = match &args[0] {
781 SqlExpression::Column(col) => col.clone(),
782 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
783 };
784
785 Ok(context
786 .get_last_value(row_index, &column)
787 .unwrap_or(DataValue::Null))
788 }
789 "SUM" => {
790 if args.is_empty() {
792 return Err(anyhow!("SUM requires 1 argument"));
793 }
794
795 let column = match &args[0] {
796 SqlExpression::Column(col) => col.clone(),
797 _ => return Err(anyhow!("SUM argument must be a column")),
798 };
799
800 Ok(context
801 .get_partition_sum(row_index, &column)
802 .unwrap_or(DataValue::Null))
803 }
804 "COUNT" => {
805 if args.is_empty() {
808 Ok(context
810 .get_partition_count(row_index, None)
811 .unwrap_or(DataValue::Null))
812 } else {
813 let column = match &args[0] {
815 SqlExpression::Column(col) => {
816 if col == "*" {
817 return Ok(context
819 .get_partition_count(row_index, None)
820 .unwrap_or(DataValue::Null));
821 }
822 col.clone()
823 }
824 SqlExpression::StringLiteral(s) if s == "*" => {
825 return Ok(context
827 .get_partition_count(row_index, None)
828 .unwrap_or(DataValue::Null));
829 }
830 _ => return Err(anyhow!("COUNT argument must be a column or *")),
831 };
832
833 Ok(context
835 .get_partition_count(row_index, Some(&column))
836 .unwrap_or(DataValue::Null))
837 }
838 }
839 _ => Err(anyhow!("Unknown window function: {}", name)),
840 }
841 }
842
843 fn evaluate_method_call(
845 &mut self,
846 object: &str,
847 method: &str,
848 args: &[SqlExpression],
849 row_index: usize,
850 ) -> Result<DataValue> {
851 let col_index = self.table.get_column_index(object).ok_or_else(|| {
853 let suggestion = self.find_similar_column(object);
854 match suggestion {
855 Some(similar) => {
856 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
857 }
858 None => anyhow!("Column '{}' not found", object),
859 }
860 })?;
861
862 let cell_value = self.table.get_value(row_index, col_index).cloned();
863
864 self.evaluate_method_on_value(
865 &cell_value.unwrap_or(DataValue::Null),
866 method,
867 args,
868 row_index,
869 )
870 }
871
872 fn evaluate_method_on_value(
874 &mut self,
875 value: &DataValue,
876 method: &str,
877 args: &[SqlExpression],
878 row_index: usize,
879 ) -> Result<DataValue> {
880 let function_name = match method.to_lowercase().as_str() {
885 "trim" => "TRIM",
886 "trimstart" | "trimbegin" => "TRIMSTART",
887 "trimend" => "TRIMEND",
888 "length" | "len" => "LENGTH",
889 "contains" => "CONTAINS",
890 "startswith" => "STARTSWITH",
891 "endswith" => "ENDSWITH",
892 "indexof" => "INDEXOF",
893 _ => method, };
895
896 if self.function_registry.get(function_name).is_some() {
898 debug!(
899 "Proxying method '{}' through function registry as '{}'",
900 method, function_name
901 );
902
903 let mut func_args = vec![value.clone()];
905
906 for arg in args {
908 func_args.push(self.evaluate(arg, row_index)?);
909 }
910
911 let func = self.function_registry.get(function_name).unwrap();
913 return func.evaluate(&func_args);
914 }
915
916 match method.to_lowercase().as_str() {
918 "trim" | "trimstart" | "trimend" => {
919 if !args.is_empty() {
920 return Err(anyhow!("{} takes no arguments", method));
921 }
922
923 let str_val = match value {
925 DataValue::String(s) => s.clone(),
926 DataValue::InternedString(s) => s.to_string(),
927 DataValue::Integer(n) => n.to_string(),
928 DataValue::Float(f) => f.to_string(),
929 DataValue::Boolean(b) => b.to_string(),
930 DataValue::DateTime(dt) => dt.clone(),
931 DataValue::Null => return Ok(DataValue::Null),
932 };
933
934 let result = match method.to_lowercase().as_str() {
935 "trim" => str_val.trim().to_string(),
936 "trimstart" => str_val.trim_start().to_string(),
937 "trimend" => str_val.trim_end().to_string(),
938 _ => unreachable!(),
939 };
940
941 Ok(DataValue::String(result))
942 }
943 "length" => {
944 if !args.is_empty() {
945 return Err(anyhow!("Length takes no arguments"));
946 }
947
948 let len = match value {
950 DataValue::String(s) => s.len(),
951 DataValue::InternedString(s) => s.len(),
952 DataValue::Integer(n) => n.to_string().len(),
953 DataValue::Float(f) => f.to_string().len(),
954 DataValue::Boolean(b) => b.to_string().len(),
955 DataValue::DateTime(dt) => dt.len(),
956 DataValue::Null => return Ok(DataValue::Integer(0)),
957 };
958
959 Ok(DataValue::Integer(len as i64))
960 }
961 "indexof" => {
962 if args.len() != 1 {
963 return Err(anyhow!("IndexOf requires exactly 1 argument"));
964 }
965
966 let search_str = match self.evaluate(&args[0], row_index)? {
968 DataValue::String(s) => s,
969 DataValue::InternedString(s) => s.to_string(),
970 DataValue::Integer(n) => n.to_string(),
971 DataValue::Float(f) => f.to_string(),
972 _ => return Err(anyhow!("IndexOf argument must be a string")),
973 };
974
975 let str_val = match value {
977 DataValue::String(s) => s.clone(),
978 DataValue::InternedString(s) => s.to_string(),
979 DataValue::Integer(n) => n.to_string(),
980 DataValue::Float(f) => f.to_string(),
981 DataValue::Boolean(b) => b.to_string(),
982 DataValue::DateTime(dt) => dt.clone(),
983 DataValue::Null => return Ok(DataValue::Integer(-1)),
984 };
985
986 let index = str_val.find(&search_str).map_or(-1, |i| i as i64);
987
988 Ok(DataValue::Integer(index))
989 }
990 "contains" => {
991 if args.len() != 1 {
992 return Err(anyhow!("Contains requires exactly 1 argument"));
993 }
994
995 let search_str = match self.evaluate(&args[0], row_index)? {
997 DataValue::String(s) => s,
998 DataValue::InternedString(s) => s.to_string(),
999 DataValue::Integer(n) => n.to_string(),
1000 DataValue::Float(f) => f.to_string(),
1001 _ => return Err(anyhow!("Contains argument must be a string")),
1002 };
1003
1004 let str_val = match value {
1006 DataValue::String(s) => s.clone(),
1007 DataValue::InternedString(s) => s.to_string(),
1008 DataValue::Integer(n) => n.to_string(),
1009 DataValue::Float(f) => f.to_string(),
1010 DataValue::Boolean(b) => b.to_string(),
1011 DataValue::DateTime(dt) => dt.clone(),
1012 DataValue::Null => return Ok(DataValue::Boolean(false)),
1013 };
1014
1015 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1017 Ok(DataValue::Boolean(result))
1018 }
1019 "startswith" => {
1020 if args.len() != 1 {
1021 return Err(anyhow!("StartsWith requires exactly 1 argument"));
1022 }
1023
1024 let prefix = match self.evaluate(&args[0], row_index)? {
1026 DataValue::String(s) => s,
1027 DataValue::InternedString(s) => s.to_string(),
1028 DataValue::Integer(n) => n.to_string(),
1029 DataValue::Float(f) => f.to_string(),
1030 _ => return Err(anyhow!("StartsWith argument must be a string")),
1031 };
1032
1033 let str_val = match value {
1035 DataValue::String(s) => s.clone(),
1036 DataValue::InternedString(s) => s.to_string(),
1037 DataValue::Integer(n) => n.to_string(),
1038 DataValue::Float(f) => f.to_string(),
1039 DataValue::Boolean(b) => b.to_string(),
1040 DataValue::DateTime(dt) => dt.clone(),
1041 DataValue::Null => return Ok(DataValue::Boolean(false)),
1042 };
1043
1044 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1046 Ok(DataValue::Boolean(result))
1047 }
1048 "endswith" => {
1049 if args.len() != 1 {
1050 return Err(anyhow!("EndsWith requires exactly 1 argument"));
1051 }
1052
1053 let suffix = match self.evaluate(&args[0], row_index)? {
1055 DataValue::String(s) => s,
1056 DataValue::InternedString(s) => s.to_string(),
1057 DataValue::Integer(n) => n.to_string(),
1058 DataValue::Float(f) => f.to_string(),
1059 _ => return Err(anyhow!("EndsWith argument must be a string")),
1060 };
1061
1062 let str_val = match value {
1064 DataValue::String(s) => s.clone(),
1065 DataValue::InternedString(s) => s.to_string(),
1066 DataValue::Integer(n) => n.to_string(),
1067 DataValue::Float(f) => f.to_string(),
1068 DataValue::Boolean(b) => b.to_string(),
1069 DataValue::DateTime(dt) => dt.clone(),
1070 DataValue::Null => return Ok(DataValue::Boolean(false)),
1071 };
1072
1073 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1075 Ok(DataValue::Boolean(result))
1076 }
1077 _ => Err(anyhow!("Unsupported method: {}", method)),
1078 }
1079 }
1080
1081 fn evaluate_case_expression(
1083 &mut self,
1084 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1085 else_branch: &Option<Box<SqlExpression>>,
1086 row_index: usize,
1087 ) -> Result<DataValue> {
1088 debug!(
1089 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1090 row_index
1091 );
1092
1093 for branch in when_branches {
1095 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1097
1098 if condition_result {
1099 debug!("CASE: WHEN condition matched, evaluating result expression");
1100 return self.evaluate(&branch.result, row_index);
1101 }
1102 }
1103
1104 if let Some(else_expr) = else_branch {
1106 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1107 self.evaluate(else_expr, row_index)
1108 } else {
1109 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1110 Ok(DataValue::Null)
1111 }
1112 }
1113
1114 fn evaluate_condition_as_bool(
1116 &mut self,
1117 expr: &SqlExpression,
1118 row_index: usize,
1119 ) -> Result<bool> {
1120 let value = self.evaluate(expr, row_index)?;
1121
1122 match value {
1123 DataValue::Boolean(b) => Ok(b),
1124 DataValue::Integer(i) => Ok(i != 0),
1125 DataValue::Float(f) => Ok(f != 0.0),
1126 DataValue::Null => Ok(false),
1127 DataValue::String(s) => Ok(!s.is_empty()),
1128 DataValue::InternedString(s) => Ok(!s.is_empty()),
1129 _ => Ok(true), }
1131 }
1132}
1133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is exactly representable and does not trip clippy's
        // `approx_constant` lint the way a literal 3.14 does.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        // The arithmetic helpers take `&self`, so no `mut` binding is needed.
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Inexact integer division is promoted to a float.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}