1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17pub struct ArithmeticEvaluator<'a> {
20 table: &'a DataTable,
21 _date_notation: String,
22 function_registry: Arc<FunctionRegistry>,
23 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
26 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
30
31impl<'a> ArithmeticEvaluator<'a> {
32 #[must_use]
33 pub fn new(table: &'a DataTable) -> Self {
34 Self {
35 table,
36 _date_notation: get_date_notation(),
37 function_registry: Arc::new(FunctionRegistry::new()),
38 aggregate_registry: Arc::new(AggregateRegistry::new()),
39 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41 visible_rows: None,
42 window_contexts: HashMap::new(),
43 table_aliases: HashMap::new(),
44 }
45 }
46
47 #[must_use]
48 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49 Self {
50 table,
51 _date_notation: date_notation,
52 function_registry: Arc::new(FunctionRegistry::new()),
53 aggregate_registry: Arc::new(AggregateRegistry::new()),
54 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56 visible_rows: None,
57 window_contexts: HashMap::new(),
58 table_aliases: HashMap::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65 self.visible_rows = Some(rows);
66 self
67 }
68
69 #[must_use]
71 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72 self.table_aliases = aliases;
73 self
74 }
75
76 #[must_use]
77 pub fn with_date_notation_and_registry(
78 table: &'a DataTable,
79 date_notation: String,
80 function_registry: Arc<FunctionRegistry>,
81 ) -> Self {
82 Self {
83 table,
84 _date_notation: date_notation,
85 function_registry,
86 aggregate_registry: Arc::new(AggregateRegistry::new()),
87 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89 visible_rows: None,
90 window_contexts: HashMap::new(),
91 table_aliases: HashMap::new(),
92 }
93 }
94
95 fn find_similar_column(&self, name: &str) -> Option<String> {
97 let columns = self.table.column_names();
98 let mut best_match: Option<(String, usize)> = None;
99
100 for col in columns {
101 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102 let max_distance = if name.len() > 10 { 3 } else { 2 };
105 if distance <= max_distance {
106 match &best_match {
107 None => best_match = Some((col, distance)),
108 Some((_, best_dist)) if distance < *best_dist => {
109 best_match = Some((col, distance));
110 }
111 _ => {}
112 }
113 }
114 }
115
116 best_match.map(|(name, _)| name)
117 }
118
119 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123 }
124
125 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127 debug!(
128 "ArithmeticEvaluator: evaluating {:?} for row {}",
129 expr, row_index
130 );
131
132 match expr {
133 SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137 SqlExpression::Null => Ok(DataValue::Null),
138 SqlExpression::BinaryOp { left, op, right } => {
139 self.evaluate_binary_op(left, op, right, row_index)
140 }
141 SqlExpression::FunctionCall {
142 name,
143 args,
144 distinct,
145 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146 SqlExpression::WindowFunction {
147 name,
148 args,
149 window_spec,
150 } => self.evaluate_window_function(name, args, window_spec, row_index),
151 SqlExpression::MethodCall {
152 object,
153 method,
154 args,
155 } => self.evaluate_method_call(object, method, args, row_index),
156 SqlExpression::ChainedMethodCall { base, method, args } => {
157 let base_value = self.evaluate(base, row_index)?;
159 self.evaluate_method_on_value(&base_value, method, args, row_index)
160 }
161 SqlExpression::CaseExpression {
162 when_branches,
163 else_branch,
164 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165 SqlExpression::SimpleCaseExpression {
166 expr,
167 when_branches,
168 else_branch,
169 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170 _ => Err(anyhow!(
171 "Unsupported expression type for arithmetic evaluation: {:?}",
172 expr
173 )),
174 }
175 }
176
177 fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
179 if let Some(table_prefix) = &column_ref.table_prefix {
180 let qualified_name = format!("{}.{}", table_prefix, column_ref.name);
182
183 if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
184 return self
185 .table
186 .get_value(row_index, col_idx)
187 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
188 .map(|v| v.clone());
189 }
190
191 Err(anyhow!("Column '{}' not found", qualified_name))
193 } else {
194 self.evaluate_column(&column_ref.name, row_index)
196 }
197 }
198
199 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
201 let resolved_column = if column_name.contains('.') {
203 if let Some(dot_pos) = column_name.rfind('.') {
205 let _table_or_alias = &column_name[..dot_pos];
206 let col_name = &column_name[dot_pos + 1..];
207
208 debug!(
211 "Resolving qualified column: {} -> {}",
212 column_name, col_name
213 );
214 col_name.to_string()
215 } else {
216 column_name.to_string()
217 }
218 } else {
219 column_name.to_string()
220 };
221
222 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
223 idx
224 } else if resolved_column != column_name {
225 if let Some(idx) = self.table.get_column_index(column_name) {
227 idx
228 } else {
229 let suggestion = self.find_similar_column(&resolved_column);
230 return Err(match suggestion {
231 Some(similar) => anyhow!(
232 "Column '{}' not found. Did you mean '{}'?",
233 column_name,
234 similar
235 ),
236 None => anyhow!("Column '{}' not found", column_name),
237 });
238 }
239 } else {
240 let suggestion = self.find_similar_column(&resolved_column);
241 return Err(match suggestion {
242 Some(similar) => anyhow!(
243 "Column '{}' not found. Did you mean '{}'?",
244 column_name,
245 similar
246 ),
247 None => anyhow!("Column '{}' not found", column_name),
248 });
249 };
250
251 if row_index >= self.table.row_count() {
252 return Err(anyhow!("Row index {} out of bounds", row_index));
253 }
254
255 let row = self
256 .table
257 .get_row(row_index)
258 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
259
260 let value = row
261 .get(col_index)
262 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
263
264 Ok(value.clone())
265 }
266
267 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
269 if let Ok(int_val) = number_str.parse::<i64>() {
271 return Ok(DataValue::Integer(int_val));
272 }
273
274 if let Ok(float_val) = number_str.parse::<f64>() {
276 return Ok(DataValue::Float(float_val));
277 }
278
279 Err(anyhow!("Invalid number literal: {}", number_str))
280 }
281
282 fn evaluate_binary_op(
284 &mut self,
285 left: &SqlExpression,
286 op: &str,
287 right: &SqlExpression,
288 row_index: usize,
289 ) -> Result<DataValue> {
290 let left_val = self.evaluate(left, row_index)?;
291 let right_val = self.evaluate(right, row_index)?;
292
293 debug!(
294 "ArithmeticEvaluator: {} {} {}",
295 self.format_value(&left_val),
296 op,
297 self.format_value(&right_val)
298 );
299
300 match op {
301 "+" => self.add_values(&left_val, &right_val),
302 "-" => self.subtract_values(&left_val, &right_val),
303 "*" => self.multiply_values(&left_val, &right_val),
304 "/" => self.divide_values(&left_val, &right_val),
305 "%" => {
306 let args = vec![left.clone(), right.clone()];
308 self.evaluate_function("MOD", &args, row_index)
309 }
310 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
313 let result = compare_with_op(&left_val, &right_val, op, false);
314 Ok(DataValue::Boolean(result))
315 }
316 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
318 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
319 "AND" => {
321 let left_bool = self.to_bool(&left_val)?;
322 let right_bool = self.to_bool(&right_val)?;
323 Ok(DataValue::Boolean(left_bool && right_bool))
324 }
325 "OR" => {
326 let left_bool = self.to_bool(&left_val)?;
327 let right_bool = self.to_bool(&right_val)?;
328 Ok(DataValue::Boolean(left_bool || right_bool))
329 }
330 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
331 }
332 }
333
334 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
336 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
338 return Ok(DataValue::Null);
339 }
340
341 match (left, right) {
342 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
343 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
344 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
345 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
346 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
347 }
348 }
349
350 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
352 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
354 return Ok(DataValue::Null);
355 }
356
357 match (left, right) {
358 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
359 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
360 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
361 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
362 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
363 }
364 }
365
366 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
368 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
370 return Ok(DataValue::Null);
371 }
372
373 match (left, right) {
374 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
375 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
376 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
377 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
378 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
379 }
380 }
381
382 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
384 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
386 return Ok(DataValue::Null);
387 }
388
389 let is_zero = match right {
391 DataValue::Integer(0) => true,
392 DataValue::Float(f) if *f == 0.0 => true, _ => false,
394 };
395
396 if is_zero {
397 return Err(anyhow!("Division by zero"));
398 }
399
400 match (left, right) {
401 (DataValue::Integer(a), DataValue::Integer(b)) => {
402 if a % b == 0 {
404 Ok(DataValue::Integer(a / b))
405 } else {
406 Ok(DataValue::Float(*a as f64 / *b as f64))
407 }
408 }
409 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
410 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
411 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
412 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
413 }
414 }
415
416 fn format_value(&self, value: &DataValue) -> String {
418 match value {
419 DataValue::Integer(i) => i.to_string(),
420 DataValue::Float(f) => f.to_string(),
421 DataValue::String(s) => format!("'{s}'"),
422 _ => format!("{value:?}"),
423 }
424 }
425
426 fn to_bool(&self, value: &DataValue) -> Result<bool> {
428 match value {
429 DataValue::Boolean(b) => Ok(*b),
430 DataValue::Integer(i) => Ok(*i != 0),
431 DataValue::Float(f) => Ok(*f != 0.0),
432 DataValue::Null => Ok(false),
433 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
434 }
435 }
436
437 fn evaluate_function_with_distinct(
439 &mut self,
440 name: &str,
441 args: &[SqlExpression],
442 distinct: bool,
443 row_index: usize,
444 ) -> Result<DataValue> {
445 if distinct {
447 let name_upper = name.to_uppercase();
448
449 if self.aggregate_registry.is_aggregate(&name_upper)
451 || self.new_aggregate_registry.contains(&name_upper)
452 {
453 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
454 } else {
455 return Err(anyhow!(
456 "DISTINCT can only be used with aggregate functions"
457 ));
458 }
459 }
460
461 self.evaluate_function(name, args, row_index)
463 }
464
465 fn evaluate_aggregate_with_distinct(
466 &mut self,
467 name: &str,
468 args: &[SqlExpression],
469 _row_index: usize,
470 ) -> Result<DataValue> {
471 let name_upper = name.to_uppercase();
472
473 if self.new_aggregate_registry.get(&name_upper).is_some() {
475 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
476 visible.clone()
477 } else {
478 (0..self.table.rows.len()).collect()
479 };
480
481 let mut vals = Vec::new();
483 for &row_idx in &rows_to_process {
484 if !args.is_empty() {
485 let value = self.evaluate(&args[0], row_idx)?;
486 vals.push(value);
487 }
488 }
489
490 let mut seen = HashSet::new();
492 let unique_values: Vec<_> = vals
493 .into_iter()
494 .filter(|v| {
495 let key = format!("{:?}", v);
496 seen.insert(key)
497 })
498 .collect();
499
500 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
502 let mut state = agg_func.create_state();
503
504 for value in &unique_values {
506 state.accumulate(value)?;
507 }
508
509 return Ok(state.finalize());
510 }
511
512 if self.aggregate_registry.get(&name_upper).is_some() {
514 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
516 visible.clone()
517 } else {
518 (0..self.table.rows.len()).collect()
519 };
520
521 if name_upper == "STRING_AGG" && args.len() >= 2 {
523 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
525 if args.len() >= 2 {
527 let separator = self.evaluate(&args[1], 0)?; match separator {
529 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
530 DataValue::InternedString(s) => {
531 crate::sql::aggregates::StringAggState::new(&s)
532 }
533 _ => crate::sql::aggregates::StringAggState::new(","), }
535 } else {
536 crate::sql::aggregates::StringAggState::new(",")
537 },
538 );
539
540 let mut seen_values = HashSet::new();
543
544 for &row_idx in &rows_to_process {
545 let value = self.evaluate(&args[0], row_idx)?;
546
547 if !seen_values.insert(value.clone()) {
549 continue; }
551
552 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
554 agg_func.accumulate(&mut state, &value)?;
555 }
556
557 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
559 return Ok(agg_func.finalize(state));
560 }
561
562 let mut vals = Vec::new();
565 for &row_idx in &rows_to_process {
566 if !args.is_empty() {
567 let value = self.evaluate(&args[0], row_idx)?;
568 vals.push(value);
569 }
570 }
571
572 let mut seen = HashSet::new();
574 let mut unique_values = Vec::new();
575 for value in vals {
576 if seen.insert(value.clone()) {
577 unique_values.push(value);
578 }
579 }
580
581 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
583 let mut state = agg_func.init();
584
585 for value in &unique_values {
587 agg_func.accumulate(&mut state, value)?;
588 }
589
590 return Ok(agg_func.finalize(state));
591 }
592
593 Err(anyhow!("Unknown aggregate function: {}", name))
594 }
595
596 fn evaluate_function(
597 &mut self,
598 name: &str,
599 args: &[SqlExpression],
600 row_index: usize,
601 ) -> Result<DataValue> {
602 let name_upper = name.to_uppercase();
604
605 if self.new_aggregate_registry.get(&name_upper).is_some() {
607 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
609 visible.clone()
610 } else {
611 (0..self.table.rows.len()).collect()
612 };
613
614 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
616 let mut state = agg_func.create_state();
617
618 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
620 if args.is_empty()
621 || (args.len() == 1
622 && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
623 || (args.len() == 1
624 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
625 {
626 for _ in &rows_to_process {
628 state.accumulate(&DataValue::Integer(1))?;
629 }
630 } else {
631 for &row_idx in &rows_to_process {
633 let value = self.evaluate(&args[0], row_idx)?;
634 state.accumulate(&value)?;
635 }
636 }
637 } else {
638 if !args.is_empty() {
640 for &row_idx in &rows_to_process {
641 let value = self.evaluate(&args[0], row_idx)?;
642 state.accumulate(&value)?;
643 }
644 }
645 }
646
647 return Ok(state.finalize());
648 }
649
650 if self.aggregate_registry.get(&name_upper).is_some() {
652 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
654 visible.clone()
655 } else {
656 (0..self.table.rows.len()).collect()
657 };
658
659 if name_upper == "STRING_AGG" && args.len() >= 2 {
661 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
663 if args.len() >= 2 {
665 let separator = self.evaluate(&args[1], 0)?; match separator {
667 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
668 DataValue::InternedString(s) => {
669 crate::sql::aggregates::StringAggState::new(&s)
670 }
671 _ => crate::sql::aggregates::StringAggState::new(","), }
673 } else {
674 crate::sql::aggregates::StringAggState::new(",")
675 },
676 );
677
678 for &row_idx in &rows_to_process {
680 let value = self.evaluate(&args[0], row_idx)?;
681 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
683 agg_func.accumulate(&mut state, &value)?;
684 }
685
686 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
688 return Ok(agg_func.finalize(state));
689 }
690
691 let values = if !args.is_empty()
693 && !(args.len() == 1
694 && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
695 {
696 let mut vals = Vec::new();
698 for &row_idx in &rows_to_process {
699 let value = self.evaluate(&args[0], row_idx)?;
700 vals.push(value);
701 }
702 Some(vals)
703 } else {
704 None
705 };
706
707 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
709 let mut state = agg_func.init();
710
711 if let Some(values) = values {
712 for value in &values {
714 agg_func.accumulate(&mut state, value)?;
715 }
716 } else {
717 for _ in &rows_to_process {
719 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
720 }
721 }
722
723 return Ok(agg_func.finalize(state));
724 }
725
726 if self.function_registry.get(name).is_some() {
728 let mut evaluated_args = Vec::new();
730 for arg in args {
731 evaluated_args.push(self.evaluate(arg, row_index)?);
732 }
733
734 let func = self.function_registry.get(name).unwrap();
736 return func.evaluate(&evaluated_args);
737 }
738
739 Err(anyhow!("Unknown function: {}", name))
741 }
742
743 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
745 let key = format!("{:?}", spec);
747
748 if let Some(context) = self.window_contexts.get(&key) {
749 return Ok(Arc::clone(context));
750 }
751
752 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
754 let view = DataView::new(Arc::new(self.table.clone()));
756 view
759 } else {
760 DataView::new(Arc::new(self.table.clone()))
761 };
762
763 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
765
766 let context = Arc::new(context);
767 self.window_contexts.insert(key, Arc::clone(&context));
768 Ok(context)
769 }
770
771 fn evaluate_window_function(
773 &mut self,
774 name: &str,
775 args: &[SqlExpression],
776 spec: &WindowSpec,
777 row_index: usize,
778 ) -> Result<DataValue> {
779 let name_upper = name.to_uppercase();
780
781 debug!("Looking for window function {} in registry", name_upper);
783 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
784 debug!("Found window function {} in registry", name_upper);
785
786 let window_fn = window_fn_arc.as_ref();
788
789 window_fn.validate_args(args)?;
791
792 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
794
795 let context = self.get_or_create_window_context(&transformed_spec)?;
797
798 struct EvaluatorAdapter<'a, 'b> {
800 evaluator: &'a mut ArithmeticEvaluator<'b>,
801 row_index: usize,
802 }
803
804 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
805 fn evaluate(
806 &mut self,
807 expr: &SqlExpression,
808 _row_index: usize,
809 ) -> Result<DataValue> {
810 self.evaluator.evaluate(expr, self.row_index)
811 }
812 }
813
814 let mut adapter = EvaluatorAdapter {
815 evaluator: self,
816 row_index,
817 };
818
819 return window_fn.compute(&context, row_index, args, &mut adapter);
821 }
822
823 let context = self.get_or_create_window_context(spec)?;
825
826 match name_upper.as_str() {
827 "LAG" => {
828 if args.is_empty() {
830 return Err(anyhow!("LAG requires at least 1 argument"));
831 }
832
833 let column = match &args[0] {
835 SqlExpression::Column(col) => col.clone(),
836 _ => return Err(anyhow!("LAG first argument must be a column")),
837 };
838
839 let offset = if args.len() > 1 {
841 match self.evaluate(&args[1], row_index)? {
842 DataValue::Integer(i) => i as i32,
843 _ => return Err(anyhow!("LAG offset must be an integer")),
844 }
845 } else {
846 1
847 };
848
849 Ok(context
851 .get_offset_value(row_index, -offset, &column.name)
852 .unwrap_or(DataValue::Null))
853 }
854 "LEAD" => {
855 if args.is_empty() {
857 return Err(anyhow!("LEAD requires at least 1 argument"));
858 }
859
860 let column = match &args[0] {
862 SqlExpression::Column(col) => col.clone(),
863 _ => return Err(anyhow!("LEAD first argument must be a column")),
864 };
865
866 let offset = if args.len() > 1 {
868 match self.evaluate(&args[1], row_index)? {
869 DataValue::Integer(i) => i as i32,
870 _ => return Err(anyhow!("LEAD offset must be an integer")),
871 }
872 } else {
873 1
874 };
875
876 Ok(context
878 .get_offset_value(row_index, offset, &column.name)
879 .unwrap_or(DataValue::Null))
880 }
881 "ROW_NUMBER" => {
882 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
884 }
885 "FIRST_VALUE" => {
886 if args.is_empty() {
888 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
889 }
890
891 let column = match &args[0] {
892 SqlExpression::Column(col) => col.clone(),
893 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
894 };
895
896 if context.has_frame() {
898 Ok(context
899 .get_frame_first_value(row_index, &column.name)
900 .unwrap_or(DataValue::Null))
901 } else {
902 Ok(context
903 .get_first_value(row_index, &column.name)
904 .unwrap_or(DataValue::Null))
905 }
906 }
907 "LAST_VALUE" => {
908 if args.is_empty() {
910 return Err(anyhow!("LAST_VALUE requires 1 argument"));
911 }
912
913 let column = match &args[0] {
914 SqlExpression::Column(col) => col.clone(),
915 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
916 };
917
918 if context.has_frame() {
920 Ok(context
921 .get_frame_last_value(row_index, &column.name)
922 .unwrap_or(DataValue::Null))
923 } else {
924 Ok(context
925 .get_last_value(row_index, &column.name)
926 .unwrap_or(DataValue::Null))
927 }
928 }
929 "SUM" => {
930 if args.is_empty() {
932 return Err(anyhow!("SUM requires 1 argument"));
933 }
934
935 let column = match &args[0] {
936 SqlExpression::Column(col) => col.clone(),
937 _ => return Err(anyhow!("SUM argument must be a column")),
938 };
939
940 if context.has_frame() {
942 Ok(context
943 .get_frame_sum(row_index, &column.name)
944 .unwrap_or(DataValue::Null))
945 } else {
946 Ok(context
947 .get_partition_sum(row_index, &column.name)
948 .unwrap_or(DataValue::Null))
949 }
950 }
951 "AVG" => {
952 if args.is_empty() {
954 return Err(anyhow!("AVG requires 1 argument"));
955 }
956
957 let column = match &args[0] {
958 SqlExpression::Column(col) => col.clone(),
959 _ => return Err(anyhow!("AVG argument must be a column")),
960 };
961
962 Ok(context
963 .get_frame_avg(row_index, &column.name)
964 .unwrap_or(DataValue::Null))
965 }
966 "STDDEV" | "STDEV" => {
967 if args.is_empty() {
969 return Err(anyhow!("STDDEV requires 1 argument"));
970 }
971
972 let column = match &args[0] {
973 SqlExpression::Column(col) => col.clone(),
974 _ => return Err(anyhow!("STDDEV argument must be a column")),
975 };
976
977 Ok(context
978 .get_frame_stddev(row_index, &column.name)
979 .unwrap_or(DataValue::Null))
980 }
981 "VARIANCE" | "VAR" => {
982 if args.is_empty() {
984 return Err(anyhow!("VARIANCE requires 1 argument"));
985 }
986
987 let column = match &args[0] {
988 SqlExpression::Column(col) => col.clone(),
989 _ => return Err(anyhow!("VARIANCE argument must be a column")),
990 };
991
992 Ok(context
993 .get_frame_variance(row_index, &column.name)
994 .unwrap_or(DataValue::Null))
995 }
996 "MIN" => {
997 if args.is_empty() {
999 return Err(anyhow!("MIN requires 1 argument"));
1000 }
1001
1002 let column = match &args[0] {
1003 SqlExpression::Column(col) => col.clone(),
1004 _ => return Err(anyhow!("MIN argument must be a column")),
1005 };
1006
1007 let frame_rows = context.get_frame_rows(row_index);
1008 if frame_rows.is_empty() {
1009 return Ok(DataValue::Null);
1010 }
1011
1012 let source_table = context.source();
1013 let col_idx = source_table
1014 .get_column_index(&column.name)
1015 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1016
1017 let mut min_value: Option<DataValue> = None;
1018 for &row_idx in &frame_rows {
1019 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1020 if !matches!(value, DataValue::Null) {
1021 match &min_value {
1022 None => min_value = Some(value.clone()),
1023 Some(current_min) => {
1024 if value < current_min {
1025 min_value = Some(value.clone());
1026 }
1027 }
1028 }
1029 }
1030 }
1031 }
1032
1033 Ok(min_value.unwrap_or(DataValue::Null))
1034 }
1035 "MAX" => {
1036 if args.is_empty() {
1038 return Err(anyhow!("MAX requires 1 argument"));
1039 }
1040
1041 let column = match &args[0] {
1042 SqlExpression::Column(col) => col.clone(),
1043 _ => return Err(anyhow!("MAX argument must be a column")),
1044 };
1045
1046 let frame_rows = context.get_frame_rows(row_index);
1047 if frame_rows.is_empty() {
1048 return Ok(DataValue::Null);
1049 }
1050
1051 let source_table = context.source();
1052 let col_idx = source_table
1053 .get_column_index(&column.name)
1054 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1055
1056 let mut max_value: Option<DataValue> = None;
1057 for &row_idx in &frame_rows {
1058 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1059 if !matches!(value, DataValue::Null) {
1060 match &max_value {
1061 None => max_value = Some(value.clone()),
1062 Some(current_max) => {
1063 if value > current_max {
1064 max_value = Some(value.clone());
1065 }
1066 }
1067 }
1068 }
1069 }
1070 }
1071
1072 Ok(max_value.unwrap_or(DataValue::Null))
1073 }
1074 "COUNT" => {
1075 if args.is_empty() {
1079 if context.has_frame() {
1081 Ok(context
1082 .get_frame_count(row_index, None)
1083 .unwrap_or(DataValue::Null))
1084 } else {
1085 Ok(context
1086 .get_partition_count(row_index, None)
1087 .unwrap_or(DataValue::Null))
1088 }
1089 } else {
1090 let column = match &args[0] {
1092 SqlExpression::Column(col) => {
1093 if col.name == "*" {
1094 if context.has_frame() {
1096 return Ok(context
1097 .get_frame_count(row_index, None)
1098 .unwrap_or(DataValue::Null));
1099 } else {
1100 return Ok(context
1101 .get_partition_count(row_index, None)
1102 .unwrap_or(DataValue::Null));
1103 }
1104 }
1105 col.clone()
1106 }
1107 SqlExpression::StringLiteral(s) if s == "*" => {
1108 if context.has_frame() {
1110 return Ok(context
1111 .get_frame_count(row_index, None)
1112 .unwrap_or(DataValue::Null));
1113 } else {
1114 return Ok(context
1115 .get_partition_count(row_index, None)
1116 .unwrap_or(DataValue::Null));
1117 }
1118 }
1119 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1120 };
1121
1122 if context.has_frame() {
1124 Ok(context
1125 .get_frame_count(row_index, Some(&column.name))
1126 .unwrap_or(DataValue::Null))
1127 } else {
1128 Ok(context
1129 .get_partition_count(row_index, Some(&column.name))
1130 .unwrap_or(DataValue::Null))
1131 }
1132 }
1133 }
1134 _ => Err(anyhow!("Unknown window function: {}", name)),
1135 }
1136 }
1137
1138 fn evaluate_method_call(
1140 &mut self,
1141 object: &str,
1142 method: &str,
1143 args: &[SqlExpression],
1144 row_index: usize,
1145 ) -> Result<DataValue> {
1146 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1148 let suggestion = self.find_similar_column(object);
1149 match suggestion {
1150 Some(similar) => {
1151 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1152 }
1153 None => anyhow!("Column '{}' not found", object),
1154 }
1155 })?;
1156
1157 let cell_value = self.table.get_value(row_index, col_index).cloned();
1158
1159 self.evaluate_method_on_value(
1160 &cell_value.unwrap_or(DataValue::Null),
1161 method,
1162 args,
1163 row_index,
1164 )
1165 }
1166
1167 fn evaluate_method_on_value(
1169 &mut self,
1170 value: &DataValue,
1171 method: &str,
1172 args: &[SqlExpression],
1173 row_index: usize,
1174 ) -> Result<DataValue> {
1175 let function_name = match method.to_lowercase().as_str() {
1180 "trim" => "TRIM",
1181 "trimstart" | "trimbegin" => "TRIMSTART",
1182 "trimend" => "TRIMEND",
1183 "length" | "len" => "LENGTH",
1184 "contains" => "CONTAINS",
1185 "startswith" => "STARTSWITH",
1186 "endswith" => "ENDSWITH",
1187 "indexof" => "INDEXOF",
1188 _ => method, };
1190
1191 if self.function_registry.get(function_name).is_some() {
1193 debug!(
1194 "Proxying method '{}' through function registry as '{}'",
1195 method, function_name
1196 );
1197
1198 let mut func_args = vec![value.clone()];
1200
1201 for arg in args {
1203 func_args.push(self.evaluate(arg, row_index)?);
1204 }
1205
1206 let func = self.function_registry.get(function_name).unwrap();
1208 return func.evaluate(&func_args);
1209 }
1210
1211 Err(anyhow!(
1214 "Method '{}' not found. It should be registered in the function registry.",
1215 method
1216 ))
1217 }
1218
1219 fn evaluate_case_expression(
1221 &mut self,
1222 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1223 else_branch: &Option<Box<SqlExpression>>,
1224 row_index: usize,
1225 ) -> Result<DataValue> {
1226 debug!(
1227 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1228 row_index
1229 );
1230
1231 for branch in when_branches {
1233 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1235
1236 if condition_result {
1237 debug!("CASE: WHEN condition matched, evaluating result expression");
1238 return self.evaluate(&branch.result, row_index);
1239 }
1240 }
1241
1242 if let Some(else_expr) = else_branch {
1244 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1245 self.evaluate(else_expr, row_index)
1246 } else {
1247 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1248 Ok(DataValue::Null)
1249 }
1250 }
1251
1252 fn evaluate_simple_case_expression(
1254 &mut self,
1255 expr: &Box<SqlExpression>,
1256 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1257 else_branch: &Option<Box<SqlExpression>>,
1258 row_index: usize,
1259 ) -> Result<DataValue> {
1260 debug!(
1261 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1262 row_index
1263 );
1264
1265 let case_value = self.evaluate(expr, row_index)?;
1267 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1268
1269 for branch in when_branches {
1271 let when_value = self.evaluate(&branch.value, row_index)?;
1273
1274 if self.values_equal(&case_value, &when_value)? {
1276 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1277 return self.evaluate(&branch.result, row_index);
1278 }
1279 }
1280
1281 if let Some(else_expr) = else_branch {
1283 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1284 self.evaluate(else_expr, row_index)
1285 } else {
1286 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1287 Ok(DataValue::Null)
1288 }
1289 }
1290
1291 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1293 match (left, right) {
1294 (DataValue::Null, DataValue::Null) => Ok(true),
1295 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1296 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1297 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1298 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1299 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1300 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1301 (DataValue::Integer(a), DataValue::Float(b)) => {
1303 Ok((*a as f64 - b).abs() < f64::EPSILON)
1304 }
1305 (DataValue::Float(a), DataValue::Integer(b)) => {
1306 Ok((a - *b as f64).abs() < f64::EPSILON)
1307 }
1308 _ => Ok(false),
1309 }
1310 }
1311
1312 fn evaluate_condition_as_bool(
1314 &mut self,
1315 expr: &SqlExpression,
1316 row_index: usize,
1317 ) -> Result<bool> {
1318 let value = self.evaluate(expr, row_index)?;
1319
1320 match value {
1321 DataValue::Boolean(b) => Ok(b),
1322 DataValue::Integer(i) => Ok(i != 0),
1323 DataValue::Float(f) => Ok(f != 0.0),
1324 DataValue::Null => Ok(false),
1325 DataValue::String(s) => Ok(!s.is_empty()),
1326 DataValue::InternedString(s) => Ok(!s.is_empty()),
1327 _ => Ok(true), }
1329 }
1330}
1331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table: `a = 10` (Integer), `b = 2.5` (Float), `c = 4` (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 2.75 is exactly representable. Unlike 3.14, it does not trip clippy's
        // deny-by-default `approx_constant` lint (approximation of PI).
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal_mixed_numeric_and_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer and Float are compared numerically.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Integer(3))
            .unwrap());
        // NULL never equals a non-NULL value.
        assert!(!evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Null)
            .unwrap());
    }

    #[test]
    fn test_condition_truthiness_of_nonzero_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Column `a` holds Integer(10); non-zero integers are truthy.
        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        assert!(evaluator.evaluate_condition_as_bool(&expr, 0).unwrap());
    }
}