1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregates::AggregateRegistry;
6use crate::sql::functions::FunctionRegistry;
7use crate::sql::parser::ast::WindowSpec;
8use crate::sql::recursive_parser::SqlExpression;
9use crate::sql::window_context::WindowContext;
10use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tracing::debug;
15
16pub struct ArithmeticEvaluator<'a> {
19 table: &'a DataTable,
20 date_notation: String,
21 function_registry: Arc<FunctionRegistry>,
22 aggregate_registry: Arc<AggregateRegistry>,
23 window_function_registry: Arc<WindowFunctionRegistry>,
24 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
28
29impl<'a> ArithmeticEvaluator<'a> {
30 #[must_use]
31 pub fn new(table: &'a DataTable) -> Self {
32 Self {
33 table,
34 date_notation: get_date_notation(),
35 function_registry: Arc::new(FunctionRegistry::new()),
36 aggregate_registry: Arc::new(AggregateRegistry::new()),
37 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
38 visible_rows: None,
39 window_contexts: HashMap::new(),
40 table_aliases: HashMap::new(),
41 }
42 }
43
44 #[must_use]
45 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
46 Self {
47 table,
48 date_notation,
49 function_registry: Arc::new(FunctionRegistry::new()),
50 aggregate_registry: Arc::new(AggregateRegistry::new()),
51 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
52 visible_rows: None,
53 window_contexts: HashMap::new(),
54 table_aliases: HashMap::new(),
55 }
56 }
57
58 #[must_use]
60 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
61 self.visible_rows = Some(rows);
62 self
63 }
64
65 #[must_use]
67 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
68 self.table_aliases = aliases;
69 self
70 }
71
72 #[must_use]
73 pub fn with_date_notation_and_registry(
74 table: &'a DataTable,
75 date_notation: String,
76 function_registry: Arc<FunctionRegistry>,
77 ) -> Self {
78 Self {
79 table,
80 date_notation,
81 function_registry,
82 aggregate_registry: Arc::new(AggregateRegistry::new()),
83 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
84 visible_rows: None,
85 window_contexts: HashMap::new(),
86 table_aliases: HashMap::new(),
87 }
88 }
89
90 fn find_similar_column(&self, name: &str) -> Option<String> {
92 let columns = self.table.column_names();
93 let mut best_match: Option<(String, usize)> = None;
94
95 for col in columns {
96 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
97 let max_distance = if name.len() > 10 { 3 } else { 2 };
100 if distance <= max_distance {
101 match &best_match {
102 None => best_match = Some((col, distance)),
103 Some((_, best_dist)) if distance < *best_dist => {
104 best_match = Some((col, distance));
105 }
106 _ => {}
107 }
108 }
109 }
110
111 best_match.map(|(name, _)| name)
112 }
113
114 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
116 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
118 }
119
120 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
122 debug!(
123 "ArithmeticEvaluator: evaluating {:?} for row {}",
124 expr, row_index
125 );
126
127 match expr {
128 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
129 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
130 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
131 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
132 SqlExpression::Null => Ok(DataValue::Null),
133 SqlExpression::BinaryOp { left, op, right } => {
134 self.evaluate_binary_op(left, op, right, row_index)
135 }
136 SqlExpression::FunctionCall {
137 name,
138 args,
139 distinct,
140 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
141 SqlExpression::WindowFunction {
142 name,
143 args,
144 window_spec,
145 } => self.evaluate_window_function(name, args, window_spec, row_index),
146 SqlExpression::MethodCall {
147 object,
148 method,
149 args,
150 } => self.evaluate_method_call(object, method, args, row_index),
151 SqlExpression::ChainedMethodCall { base, method, args } => {
152 let base_value = self.evaluate(base, row_index)?;
154 self.evaluate_method_on_value(&base_value, method, args, row_index)
155 }
156 SqlExpression::CaseExpression {
157 when_branches,
158 else_branch,
159 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
160 SqlExpression::SimpleCaseExpression {
161 expr,
162 when_branches,
163 else_branch,
164 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
165 _ => Err(anyhow!(
166 "Unsupported expression type for arithmetic evaluation: {:?}",
167 expr
168 )),
169 }
170 }
171
172 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
174 let resolved_column = if column_name.contains('.') {
176 if let Some(dot_pos) = column_name.rfind('.') {
178 let _table_or_alias = &column_name[..dot_pos];
179 let col_name = &column_name[dot_pos + 1..];
180
181 debug!(
184 "Resolving qualified column: {} -> {}",
185 column_name, col_name
186 );
187 col_name.to_string()
188 } else {
189 column_name.to_string()
190 }
191 } else {
192 column_name.to_string()
193 };
194
195 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
196 idx
197 } else if resolved_column != column_name {
198 if let Some(idx) = self.table.get_column_index(column_name) {
200 idx
201 } else {
202 let suggestion = self.find_similar_column(&resolved_column);
203 return Err(match suggestion {
204 Some(similar) => anyhow!(
205 "Column '{}' not found. Did you mean '{}'?",
206 column_name,
207 similar
208 ),
209 None => anyhow!("Column '{}' not found", column_name),
210 });
211 }
212 } else {
213 let suggestion = self.find_similar_column(&resolved_column);
214 return Err(match suggestion {
215 Some(similar) => anyhow!(
216 "Column '{}' not found. Did you mean '{}'?",
217 column_name,
218 similar
219 ),
220 None => anyhow!("Column '{}' not found", column_name),
221 });
222 };
223
224 if row_index >= self.table.row_count() {
225 return Err(anyhow!("Row index {} out of bounds", row_index));
226 }
227
228 let row = self
229 .table
230 .get_row(row_index)
231 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
232
233 let value = row
234 .get(col_index)
235 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
236
237 Ok(value.clone())
238 }
239
240 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
242 if let Ok(int_val) = number_str.parse::<i64>() {
244 return Ok(DataValue::Integer(int_val));
245 }
246
247 if let Ok(float_val) = number_str.parse::<f64>() {
249 return Ok(DataValue::Float(float_val));
250 }
251
252 Err(anyhow!("Invalid number literal: {}", number_str))
253 }
254
255 fn evaluate_binary_op(
257 &mut self,
258 left: &SqlExpression,
259 op: &str,
260 right: &SqlExpression,
261 row_index: usize,
262 ) -> Result<DataValue> {
263 let left_val = self.evaluate(left, row_index)?;
264 let right_val = self.evaluate(right, row_index)?;
265
266 debug!(
267 "ArithmeticEvaluator: {} {} {}",
268 self.format_value(&left_val),
269 op,
270 self.format_value(&right_val)
271 );
272
273 match op {
274 "+" => self.add_values(&left_val, &right_val),
275 "-" => self.subtract_values(&left_val, &right_val),
276 "*" => self.multiply_values(&left_val, &right_val),
277 "/" => self.divide_values(&left_val, &right_val),
278 "%" => {
279 let args = vec![left.clone(), right.clone()];
281 self.evaluate_function("MOD", &args, row_index)
282 }
283 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
286 let result = compare_with_op(&left_val, &right_val, op, false);
287 Ok(DataValue::Boolean(result))
288 }
289 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
291 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
292 "AND" => {
294 let left_bool = self.to_bool(&left_val)?;
295 let right_bool = self.to_bool(&right_val)?;
296 Ok(DataValue::Boolean(left_bool && right_bool))
297 }
298 "OR" => {
299 let left_bool = self.to_bool(&left_val)?;
300 let right_bool = self.to_bool(&right_val)?;
301 Ok(DataValue::Boolean(left_bool || right_bool))
302 }
303 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
304 }
305 }
306
307 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
309 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
311 return Ok(DataValue::Null);
312 }
313
314 match (left, right) {
315 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
316 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
317 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
318 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
319 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
320 }
321 }
322
323 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
325 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
327 return Ok(DataValue::Null);
328 }
329
330 match (left, right) {
331 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
332 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
333 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
334 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
335 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
336 }
337 }
338
339 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
341 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
343 return Ok(DataValue::Null);
344 }
345
346 match (left, right) {
347 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
348 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
349 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
350 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
351 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
352 }
353 }
354
355 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
357 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
359 return Ok(DataValue::Null);
360 }
361
362 let is_zero = match right {
364 DataValue::Integer(0) => true,
365 DataValue::Float(f) if *f == 0.0 => true, _ => false,
367 };
368
369 if is_zero {
370 return Err(anyhow!("Division by zero"));
371 }
372
373 match (left, right) {
374 (DataValue::Integer(a), DataValue::Integer(b)) => {
375 if a % b == 0 {
377 Ok(DataValue::Integer(a / b))
378 } else {
379 Ok(DataValue::Float(*a as f64 / *b as f64))
380 }
381 }
382 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
383 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
384 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
385 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
386 }
387 }
388
389 fn format_value(&self, value: &DataValue) -> String {
391 match value {
392 DataValue::Integer(i) => i.to_string(),
393 DataValue::Float(f) => f.to_string(),
394 DataValue::String(s) => format!("'{s}'"),
395 _ => format!("{value:?}"),
396 }
397 }
398
399 fn to_bool(&self, value: &DataValue) -> Result<bool> {
401 match value {
402 DataValue::Boolean(b) => Ok(*b),
403 DataValue::Integer(i) => Ok(*i != 0),
404 DataValue::Float(f) => Ok(*f != 0.0),
405 DataValue::Null => Ok(false),
406 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
407 }
408 }
409
410 fn evaluate_function_with_distinct(
412 &mut self,
413 name: &str,
414 args: &[SqlExpression],
415 distinct: bool,
416 row_index: usize,
417 ) -> Result<DataValue> {
418 if distinct {
420 let name_upper = name.to_uppercase();
421
422 if self.aggregate_registry.is_aggregate(&name_upper) {
424 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
425 } else {
426 return Err(anyhow!(
427 "DISTINCT can only be used with aggregate functions"
428 ));
429 }
430 }
431
432 self.evaluate_function(name, args, row_index)
434 }
435
436 fn evaluate_aggregate_with_distinct(
437 &mut self,
438 name: &str,
439 args: &[SqlExpression],
440 _row_index: usize,
441 ) -> Result<DataValue> {
442 let name_upper = name.to_uppercase();
443
444 if self.aggregate_registry.get(&name_upper).is_some() {
446 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
448 visible.clone()
449 } else {
450 (0..self.table.rows.len()).collect()
451 };
452
453 if name_upper == "STRING_AGG" && args.len() >= 2 {
455 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
457 if args.len() >= 2 {
459 let separator = self.evaluate(&args[1], 0)?; match separator {
461 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
462 DataValue::InternedString(s) => {
463 crate::sql::aggregates::StringAggState::new(&s)
464 }
465 _ => crate::sql::aggregates::StringAggState::new(","), }
467 } else {
468 crate::sql::aggregates::StringAggState::new(",")
469 },
470 );
471
472 let mut seen_values = HashSet::new();
475
476 for &row_idx in &rows_to_process {
477 let value = self.evaluate(&args[0], row_idx)?;
478
479 if !seen_values.insert(value.clone()) {
481 continue; }
483
484 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
486 agg_func.accumulate(&mut state, &value)?;
487 }
488
489 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
491 return Ok(agg_func.finalize(state));
492 }
493
494 let mut vals = Vec::new();
497 for &row_idx in &rows_to_process {
498 if !args.is_empty() {
499 let value = self.evaluate(&args[0], row_idx)?;
500 vals.push(value);
501 }
502 }
503
504 let mut seen = HashSet::new();
506 let mut unique_values = Vec::new();
507 for value in vals {
508 if seen.insert(value.clone()) {
509 unique_values.push(value);
510 }
511 }
512
513 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
515 let mut state = agg_func.init();
516
517 for value in &unique_values {
519 agg_func.accumulate(&mut state, value)?;
520 }
521
522 return Ok(agg_func.finalize(state));
523 }
524
525 Err(anyhow!("Unknown aggregate function: {}", name))
526 }
527
528 fn evaluate_function(
529 &mut self,
530 name: &str,
531 args: &[SqlExpression],
532 row_index: usize,
533 ) -> Result<DataValue> {
534 let name_upper = name.to_uppercase();
536
537 if name_upper == "COUNT" && args.len() == 1 {
539 match &args[0] {
540 SqlExpression::Column(col) if col == "*" => {
541 let count = if let Some(ref visible) = self.visible_rows {
543 visible.len() as i64
544 } else {
545 self.table.rows.len() as i64
546 };
547 return Ok(DataValue::Integer(count));
548 }
549 SqlExpression::StringLiteral(s) if s == "*" => {
550 let count = if let Some(ref visible) = self.visible_rows {
552 visible.len() as i64
553 } else {
554 self.table.rows.len() as i64
555 };
556 return Ok(DataValue::Integer(count));
557 }
558 _ => {
559 }
561 }
562 }
563
564 if self.aggregate_registry.get(&name_upper).is_some() {
566 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
568 visible.clone()
569 } else {
570 (0..self.table.rows.len()).collect()
571 };
572
573 if name_upper == "STRING_AGG" && args.len() >= 2 {
575 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
577 if args.len() >= 2 {
579 let separator = self.evaluate(&args[1], 0)?; match separator {
581 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
582 DataValue::InternedString(s) => {
583 crate::sql::aggregates::StringAggState::new(&s)
584 }
585 _ => crate::sql::aggregates::StringAggState::new(","), }
587 } else {
588 crate::sql::aggregates::StringAggState::new(",")
589 },
590 );
591
592 for &row_idx in &rows_to_process {
594 let value = self.evaluate(&args[0], row_idx)?;
595 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
597 agg_func.accumulate(&mut state, &value)?;
598 }
599
600 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
602 return Ok(agg_func.finalize(state));
603 }
604
605 let values = if !args.is_empty()
607 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
608 {
609 let mut vals = Vec::new();
611 for &row_idx in &rows_to_process {
612 let value = self.evaluate(&args[0], row_idx)?;
613 vals.push(value);
614 }
615 Some(vals)
616 } else {
617 None
618 };
619
620 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
622 let mut state = agg_func.init();
623
624 if let Some(values) = values {
625 for value in &values {
627 agg_func.accumulate(&mut state, value)?;
628 }
629 } else {
630 for _ in &rows_to_process {
632 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
633 }
634 }
635
636 return Ok(agg_func.finalize(state));
637 }
638
639 if self.function_registry.get(name).is_some() {
641 let mut evaluated_args = Vec::new();
643 for arg in args {
644 evaluated_args.push(self.evaluate(arg, row_index)?);
645 }
646
647 let func = self.function_registry.get(name).unwrap();
649 return func.evaluate(&evaluated_args);
650 }
651
652 Err(anyhow!("Unknown function: {}", name))
654 }
655
656 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
658 let key = format!("{:?}", spec);
660
661 if let Some(context) = self.window_contexts.get(&key) {
662 return Ok(Arc::clone(context));
663 }
664
665 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
667 let view = DataView::new(Arc::new(self.table.clone()));
669 view
672 } else {
673 DataView::new(Arc::new(self.table.clone()))
674 };
675
676 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
678
679 let context = Arc::new(context);
680 self.window_contexts.insert(key, Arc::clone(&context));
681 Ok(context)
682 }
683
684 fn evaluate_window_function(
686 &mut self,
687 name: &str,
688 args: &[SqlExpression],
689 spec: &WindowSpec,
690 row_index: usize,
691 ) -> Result<DataValue> {
692 let name_upper = name.to_uppercase();
693
694 debug!("Looking for window function {} in registry", name_upper);
696 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
697 debug!("Found window function {} in registry", name_upper);
698
699 let window_fn = window_fn_arc.as_ref();
701
702 window_fn.validate_args(args)?;
704
705 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
707
708 let context = self.get_or_create_window_context(&transformed_spec)?;
710
711 struct EvaluatorAdapter<'a, 'b> {
713 evaluator: &'a mut ArithmeticEvaluator<'b>,
714 row_index: usize,
715 }
716
717 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
718 fn evaluate(
719 &mut self,
720 expr: &SqlExpression,
721 _row_index: usize,
722 ) -> Result<DataValue> {
723 self.evaluator.evaluate(expr, self.row_index)
724 }
725 }
726
727 let mut adapter = EvaluatorAdapter {
728 evaluator: self,
729 row_index,
730 };
731
732 return window_fn.compute(&context, row_index, args, &mut adapter);
734 }
735
736 let context = self.get_or_create_window_context(spec)?;
738
739 match name_upper.as_str() {
740 "LAG" => {
741 if args.is_empty() {
743 return Err(anyhow!("LAG requires at least 1 argument"));
744 }
745
746 let column = match &args[0] {
748 SqlExpression::Column(col) => col.clone(),
749 _ => return Err(anyhow!("LAG first argument must be a column")),
750 };
751
752 let offset = if args.len() > 1 {
754 match self.evaluate(&args[1], row_index)? {
755 DataValue::Integer(i) => i as i32,
756 _ => return Err(anyhow!("LAG offset must be an integer")),
757 }
758 } else {
759 1
760 };
761
762 Ok(context
764 .get_offset_value(row_index, -offset, &column)
765 .unwrap_or(DataValue::Null))
766 }
767 "LEAD" => {
768 if args.is_empty() {
770 return Err(anyhow!("LEAD requires at least 1 argument"));
771 }
772
773 let column = match &args[0] {
775 SqlExpression::Column(col) => col.clone(),
776 _ => return Err(anyhow!("LEAD first argument must be a column")),
777 };
778
779 let offset = if args.len() > 1 {
781 match self.evaluate(&args[1], row_index)? {
782 DataValue::Integer(i) => i as i32,
783 _ => return Err(anyhow!("LEAD offset must be an integer")),
784 }
785 } else {
786 1
787 };
788
789 Ok(context
791 .get_offset_value(row_index, offset, &column)
792 .unwrap_or(DataValue::Null))
793 }
794 "ROW_NUMBER" => {
795 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
797 }
798 "FIRST_VALUE" => {
799 if args.is_empty() {
801 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
802 }
803
804 let column = match &args[0] {
805 SqlExpression::Column(col) => col.clone(),
806 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
807 };
808
809 if context.has_frame() {
811 Ok(context
812 .get_frame_first_value(row_index, &column)
813 .unwrap_or(DataValue::Null))
814 } else {
815 Ok(context
816 .get_first_value(row_index, &column)
817 .unwrap_or(DataValue::Null))
818 }
819 }
820 "LAST_VALUE" => {
821 if args.is_empty() {
823 return Err(anyhow!("LAST_VALUE requires 1 argument"));
824 }
825
826 let column = match &args[0] {
827 SqlExpression::Column(col) => col.clone(),
828 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
829 };
830
831 if context.has_frame() {
833 Ok(context
834 .get_frame_last_value(row_index, &column)
835 .unwrap_or(DataValue::Null))
836 } else {
837 Ok(context
838 .get_last_value(row_index, &column)
839 .unwrap_or(DataValue::Null))
840 }
841 }
842 "SUM" => {
843 if args.is_empty() {
845 return Err(anyhow!("SUM requires 1 argument"));
846 }
847
848 let column = match &args[0] {
849 SqlExpression::Column(col) => col.clone(),
850 _ => return Err(anyhow!("SUM argument must be a column")),
851 };
852
853 if context.has_frame() {
855 Ok(context
856 .get_frame_sum(row_index, &column)
857 .unwrap_or(DataValue::Null))
858 } else {
859 Ok(context
860 .get_partition_sum(row_index, &column)
861 .unwrap_or(DataValue::Null))
862 }
863 }
864 "AVG" => {
865 if args.is_empty() {
867 return Err(anyhow!("AVG requires 1 argument"));
868 }
869
870 let column = match &args[0] {
871 SqlExpression::Column(col) => col.clone(),
872 _ => return Err(anyhow!("AVG argument must be a column")),
873 };
874
875 Ok(context
876 .get_frame_avg(row_index, &column)
877 .unwrap_or(DataValue::Null))
878 }
879 "STDDEV" | "STDEV" => {
880 if args.is_empty() {
882 return Err(anyhow!("STDDEV requires 1 argument"));
883 }
884
885 let column = match &args[0] {
886 SqlExpression::Column(col) => col.clone(),
887 _ => return Err(anyhow!("STDDEV argument must be a column")),
888 };
889
890 Ok(context
891 .get_frame_stddev(row_index, &column)
892 .unwrap_or(DataValue::Null))
893 }
894 "VARIANCE" | "VAR" => {
895 if args.is_empty() {
897 return Err(anyhow!("VARIANCE requires 1 argument"));
898 }
899
900 let column = match &args[0] {
901 SqlExpression::Column(col) => col.clone(),
902 _ => return Err(anyhow!("VARIANCE argument must be a column")),
903 };
904
905 Ok(context
906 .get_frame_variance(row_index, &column)
907 .unwrap_or(DataValue::Null))
908 }
909 "MIN" => {
910 if args.is_empty() {
912 return Err(anyhow!("MIN requires 1 argument"));
913 }
914
915 let column = match &args[0] {
916 SqlExpression::Column(col) => col.clone(),
917 _ => return Err(anyhow!("MIN argument must be a column")),
918 };
919
920 let frame_rows = context.get_frame_rows(row_index);
921 if frame_rows.is_empty() {
922 return Ok(DataValue::Null);
923 }
924
925 let source_table = context.source();
926 let col_idx = source_table
927 .get_column_index(&column)
928 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
929
930 let mut min_value: Option<DataValue> = None;
931 for &row_idx in &frame_rows {
932 if let Some(value) = source_table.get_value(row_idx, col_idx) {
933 if !matches!(value, DataValue::Null) {
934 match &min_value {
935 None => min_value = Some(value.clone()),
936 Some(current_min) => {
937 if value < current_min {
938 min_value = Some(value.clone());
939 }
940 }
941 }
942 }
943 }
944 }
945
946 Ok(min_value.unwrap_or(DataValue::Null))
947 }
948 "MAX" => {
949 if args.is_empty() {
951 return Err(anyhow!("MAX requires 1 argument"));
952 }
953
954 let column = match &args[0] {
955 SqlExpression::Column(col) => col.clone(),
956 _ => return Err(anyhow!("MAX argument must be a column")),
957 };
958
959 let frame_rows = context.get_frame_rows(row_index);
960 if frame_rows.is_empty() {
961 return Ok(DataValue::Null);
962 }
963
964 let source_table = context.source();
965 let col_idx = source_table
966 .get_column_index(&column)
967 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
968
969 let mut max_value: Option<DataValue> = None;
970 for &row_idx in &frame_rows {
971 if let Some(value) = source_table.get_value(row_idx, col_idx) {
972 if !matches!(value, DataValue::Null) {
973 match &max_value {
974 None => max_value = Some(value.clone()),
975 Some(current_max) => {
976 if value > current_max {
977 max_value = Some(value.clone());
978 }
979 }
980 }
981 }
982 }
983 }
984
985 Ok(max_value.unwrap_or(DataValue::Null))
986 }
987 "COUNT" => {
988 if args.is_empty() {
992 if context.has_frame() {
994 Ok(context
995 .get_frame_count(row_index, None)
996 .unwrap_or(DataValue::Null))
997 } else {
998 Ok(context
999 .get_partition_count(row_index, None)
1000 .unwrap_or(DataValue::Null))
1001 }
1002 } else {
1003 let column = match &args[0] {
1005 SqlExpression::Column(col) => {
1006 if col == "*" {
1007 if context.has_frame() {
1009 return Ok(context
1010 .get_frame_count(row_index, None)
1011 .unwrap_or(DataValue::Null));
1012 } else {
1013 return Ok(context
1014 .get_partition_count(row_index, None)
1015 .unwrap_or(DataValue::Null));
1016 }
1017 }
1018 col.clone()
1019 }
1020 SqlExpression::StringLiteral(s) if s == "*" => {
1021 if context.has_frame() {
1023 return Ok(context
1024 .get_frame_count(row_index, None)
1025 .unwrap_or(DataValue::Null));
1026 } else {
1027 return Ok(context
1028 .get_partition_count(row_index, None)
1029 .unwrap_or(DataValue::Null));
1030 }
1031 }
1032 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1033 };
1034
1035 if context.has_frame() {
1037 Ok(context
1038 .get_frame_count(row_index, Some(&column))
1039 .unwrap_or(DataValue::Null))
1040 } else {
1041 Ok(context
1042 .get_partition_count(row_index, Some(&column))
1043 .unwrap_or(DataValue::Null))
1044 }
1045 }
1046 }
1047 _ => Err(anyhow!("Unknown window function: {}", name)),
1048 }
1049 }
1050
1051 fn evaluate_method_call(
1053 &mut self,
1054 object: &str,
1055 method: &str,
1056 args: &[SqlExpression],
1057 row_index: usize,
1058 ) -> Result<DataValue> {
1059 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1061 let suggestion = self.find_similar_column(object);
1062 match suggestion {
1063 Some(similar) => {
1064 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1065 }
1066 None => anyhow!("Column '{}' not found", object),
1067 }
1068 })?;
1069
1070 let cell_value = self.table.get_value(row_index, col_index).cloned();
1071
1072 self.evaluate_method_on_value(
1073 &cell_value.unwrap_or(DataValue::Null),
1074 method,
1075 args,
1076 row_index,
1077 )
1078 }
1079
1080 fn evaluate_method_on_value(
1082 &mut self,
1083 value: &DataValue,
1084 method: &str,
1085 args: &[SqlExpression],
1086 row_index: usize,
1087 ) -> Result<DataValue> {
1088 let function_name = match method.to_lowercase().as_str() {
1093 "trim" => "TRIM",
1094 "trimstart" | "trimbegin" => "TRIMSTART",
1095 "trimend" => "TRIMEND",
1096 "length" | "len" => "LENGTH",
1097 "contains" => "CONTAINS",
1098 "startswith" => "STARTSWITH",
1099 "endswith" => "ENDSWITH",
1100 "indexof" => "INDEXOF",
1101 _ => method, };
1103
1104 if self.function_registry.get(function_name).is_some() {
1106 debug!(
1107 "Proxying method '{}' through function registry as '{}'",
1108 method, function_name
1109 );
1110
1111 let mut func_args = vec![value.clone()];
1113
1114 for arg in args {
1116 func_args.push(self.evaluate(arg, row_index)?);
1117 }
1118
1119 let func = self.function_registry.get(function_name).unwrap();
1121 return func.evaluate(&func_args);
1122 }
1123
1124 Err(anyhow!(
1127 "Method '{}' not found. It should be registered in the function registry.",
1128 method
1129 ))
1130 }
1131
1132 fn evaluate_case_expression(
1134 &mut self,
1135 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1136 else_branch: &Option<Box<SqlExpression>>,
1137 row_index: usize,
1138 ) -> Result<DataValue> {
1139 debug!(
1140 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1141 row_index
1142 );
1143
1144 for branch in when_branches {
1146 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1148
1149 if condition_result {
1150 debug!("CASE: WHEN condition matched, evaluating result expression");
1151 return self.evaluate(&branch.result, row_index);
1152 }
1153 }
1154
1155 if let Some(else_expr) = else_branch {
1157 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1158 self.evaluate(else_expr, row_index)
1159 } else {
1160 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1161 Ok(DataValue::Null)
1162 }
1163 }
1164
1165 fn evaluate_simple_case_expression(
1167 &mut self,
1168 expr: &Box<SqlExpression>,
1169 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1170 else_branch: &Option<Box<SqlExpression>>,
1171 row_index: usize,
1172 ) -> Result<DataValue> {
1173 debug!(
1174 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1175 row_index
1176 );
1177
1178 let case_value = self.evaluate(expr, row_index)?;
1180 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1181
1182 for branch in when_branches {
1184 let when_value = self.evaluate(&branch.value, row_index)?;
1186
1187 if self.values_equal(&case_value, &when_value)? {
1189 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1190 return self.evaluate(&branch.result, row_index);
1191 }
1192 }
1193
1194 if let Some(else_expr) = else_branch {
1196 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1197 self.evaluate(else_expr, row_index)
1198 } else {
1199 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1200 Ok(DataValue::Null)
1201 }
1202 }
1203
1204 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1206 match (left, right) {
1207 (DataValue::Null, DataValue::Null) => Ok(true),
1208 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1209 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1210 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1211 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1212 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1213 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1214 (DataValue::Integer(a), DataValue::Float(b)) => {
1216 Ok((*a as f64 - b).abs() < f64::EPSILON)
1217 }
1218 (DataValue::Float(a), DataValue::Integer(b)) => {
1219 Ok((a - *b as f64).abs() < f64::EPSILON)
1220 }
1221 _ => Ok(false),
1222 }
1223 }
1224
1225 fn evaluate_condition_as_bool(
1227 &mut self,
1228 expr: &SqlExpression,
1229 row_index: usize,
1230 ) -> Result<bool> {
1231 let value = self.evaluate(expr, row_index)?;
1232
1233 match value {
1234 DataValue::Boolean(b) => Ok(b),
1235 DataValue::Integer(i) => Ok(i != 0),
1236 DataValue::Float(f) => Ok(f != 0.0),
1237 DataValue::Null => Ok(false),
1238 DataValue::String(s) => Ok(!s.is_empty()),
1239 DataValue::InternedString(s) => Ok(!s.is_empty()),
1240 _ => Ok(true), }
1242 }
1243}
1244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        let expr = SqlExpression::NumberLiteral("3.14".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(3.14));
    }

    // The arithmetic helpers below take `&self`, so the evaluator binding
    // does not need to be mutable.

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Inexact integer division is promoted to a float.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}