1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::WindowSpec;
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17pub struct ArithmeticEvaluator<'a> {
20 table: &'a DataTable,
21 date_notation: String,
22 function_registry: Arc<FunctionRegistry>,
23 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
26 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
30
31impl<'a> ArithmeticEvaluator<'a> {
32 #[must_use]
33 pub fn new(table: &'a DataTable) -> Self {
34 Self {
35 table,
36 date_notation: get_date_notation(),
37 function_registry: Arc::new(FunctionRegistry::new()),
38 aggregate_registry: Arc::new(AggregateRegistry::new()),
39 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41 visible_rows: None,
42 window_contexts: HashMap::new(),
43 table_aliases: HashMap::new(),
44 }
45 }
46
47 #[must_use]
48 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49 Self {
50 table,
51 date_notation,
52 function_registry: Arc::new(FunctionRegistry::new()),
53 aggregate_registry: Arc::new(AggregateRegistry::new()),
54 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56 visible_rows: None,
57 window_contexts: HashMap::new(),
58 table_aliases: HashMap::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65 self.visible_rows = Some(rows);
66 self
67 }
68
69 #[must_use]
71 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72 self.table_aliases = aliases;
73 self
74 }
75
76 #[must_use]
77 pub fn with_date_notation_and_registry(
78 table: &'a DataTable,
79 date_notation: String,
80 function_registry: Arc<FunctionRegistry>,
81 ) -> Self {
82 Self {
83 table,
84 date_notation,
85 function_registry,
86 aggregate_registry: Arc::new(AggregateRegistry::new()),
87 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89 visible_rows: None,
90 window_contexts: HashMap::new(),
91 table_aliases: HashMap::new(),
92 }
93 }
94
95 fn find_similar_column(&self, name: &str) -> Option<String> {
97 let columns = self.table.column_names();
98 let mut best_match: Option<(String, usize)> = None;
99
100 for col in columns {
101 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102 let max_distance = if name.len() > 10 { 3 } else { 2 };
105 if distance <= max_distance {
106 match &best_match {
107 None => best_match = Some((col, distance)),
108 Some((_, best_dist)) if distance < *best_dist => {
109 best_match = Some((col, distance));
110 }
111 _ => {}
112 }
113 }
114 }
115
116 best_match.map(|(name, _)| name)
117 }
118
119 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123 }
124
125 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127 debug!(
128 "ArithmeticEvaluator: evaluating {:?} for row {}",
129 expr, row_index
130 );
131
132 match expr {
133 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
134 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137 SqlExpression::Null => Ok(DataValue::Null),
138 SqlExpression::BinaryOp { left, op, right } => {
139 self.evaluate_binary_op(left, op, right, row_index)
140 }
141 SqlExpression::FunctionCall {
142 name,
143 args,
144 distinct,
145 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146 SqlExpression::WindowFunction {
147 name,
148 args,
149 window_spec,
150 } => self.evaluate_window_function(name, args, window_spec, row_index),
151 SqlExpression::MethodCall {
152 object,
153 method,
154 args,
155 } => self.evaluate_method_call(object, method, args, row_index),
156 SqlExpression::ChainedMethodCall { base, method, args } => {
157 let base_value = self.evaluate(base, row_index)?;
159 self.evaluate_method_on_value(&base_value, method, args, row_index)
160 }
161 SqlExpression::CaseExpression {
162 when_branches,
163 else_branch,
164 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165 SqlExpression::SimpleCaseExpression {
166 expr,
167 when_branches,
168 else_branch,
169 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170 _ => Err(anyhow!(
171 "Unsupported expression type for arithmetic evaluation: {:?}",
172 expr
173 )),
174 }
175 }
176
177 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
179 let resolved_column = if column_name.contains('.') {
181 if let Some(dot_pos) = column_name.rfind('.') {
183 let _table_or_alias = &column_name[..dot_pos];
184 let col_name = &column_name[dot_pos + 1..];
185
186 debug!(
189 "Resolving qualified column: {} -> {}",
190 column_name, col_name
191 );
192 col_name.to_string()
193 } else {
194 column_name.to_string()
195 }
196 } else {
197 column_name.to_string()
198 };
199
200 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
201 idx
202 } else if resolved_column != column_name {
203 if let Some(idx) = self.table.get_column_index(column_name) {
205 idx
206 } else {
207 let suggestion = self.find_similar_column(&resolved_column);
208 return Err(match suggestion {
209 Some(similar) => anyhow!(
210 "Column '{}' not found. Did you mean '{}'?",
211 column_name,
212 similar
213 ),
214 None => anyhow!("Column '{}' not found", column_name),
215 });
216 }
217 } else {
218 let suggestion = self.find_similar_column(&resolved_column);
219 return Err(match suggestion {
220 Some(similar) => anyhow!(
221 "Column '{}' not found. Did you mean '{}'?",
222 column_name,
223 similar
224 ),
225 None => anyhow!("Column '{}' not found", column_name),
226 });
227 };
228
229 if row_index >= self.table.row_count() {
230 return Err(anyhow!("Row index {} out of bounds", row_index));
231 }
232
233 let row = self
234 .table
235 .get_row(row_index)
236 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
237
238 let value = row
239 .get(col_index)
240 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
241
242 Ok(value.clone())
243 }
244
245 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
247 if let Ok(int_val) = number_str.parse::<i64>() {
249 return Ok(DataValue::Integer(int_val));
250 }
251
252 if let Ok(float_val) = number_str.parse::<f64>() {
254 return Ok(DataValue::Float(float_val));
255 }
256
257 Err(anyhow!("Invalid number literal: {}", number_str))
258 }
259
260 fn evaluate_binary_op(
262 &mut self,
263 left: &SqlExpression,
264 op: &str,
265 right: &SqlExpression,
266 row_index: usize,
267 ) -> Result<DataValue> {
268 let left_val = self.evaluate(left, row_index)?;
269 let right_val = self.evaluate(right, row_index)?;
270
271 debug!(
272 "ArithmeticEvaluator: {} {} {}",
273 self.format_value(&left_val),
274 op,
275 self.format_value(&right_val)
276 );
277
278 match op {
279 "+" => self.add_values(&left_val, &right_val),
280 "-" => self.subtract_values(&left_val, &right_val),
281 "*" => self.multiply_values(&left_val, &right_val),
282 "/" => self.divide_values(&left_val, &right_val),
283 "%" => {
284 let args = vec![left.clone(), right.clone()];
286 self.evaluate_function("MOD", &args, row_index)
287 }
288 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
291 let result = compare_with_op(&left_val, &right_val, op, false);
292 Ok(DataValue::Boolean(result))
293 }
294 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
296 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
297 "AND" => {
299 let left_bool = self.to_bool(&left_val)?;
300 let right_bool = self.to_bool(&right_val)?;
301 Ok(DataValue::Boolean(left_bool && right_bool))
302 }
303 "OR" => {
304 let left_bool = self.to_bool(&left_val)?;
305 let right_bool = self.to_bool(&right_val)?;
306 Ok(DataValue::Boolean(left_bool || right_bool))
307 }
308 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
309 }
310 }
311
312 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316 return Ok(DataValue::Null);
317 }
318
319 match (left, right) {
320 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
321 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
322 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
323 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
324 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
325 }
326 }
327
328 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332 return Ok(DataValue::Null);
333 }
334
335 match (left, right) {
336 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
337 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
338 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
339 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
340 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
341 }
342 }
343
344 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348 return Ok(DataValue::Null);
349 }
350
351 match (left, right) {
352 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
353 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
354 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
355 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
356 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
357 }
358 }
359
360 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
362 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
364 return Ok(DataValue::Null);
365 }
366
367 let is_zero = match right {
369 DataValue::Integer(0) => true,
370 DataValue::Float(f) if *f == 0.0 => true, _ => false,
372 };
373
374 if is_zero {
375 return Err(anyhow!("Division by zero"));
376 }
377
378 match (left, right) {
379 (DataValue::Integer(a), DataValue::Integer(b)) => {
380 if a % b == 0 {
382 Ok(DataValue::Integer(a / b))
383 } else {
384 Ok(DataValue::Float(*a as f64 / *b as f64))
385 }
386 }
387 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
388 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
389 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
390 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
391 }
392 }
393
394 fn format_value(&self, value: &DataValue) -> String {
396 match value {
397 DataValue::Integer(i) => i.to_string(),
398 DataValue::Float(f) => f.to_string(),
399 DataValue::String(s) => format!("'{s}'"),
400 _ => format!("{value:?}"),
401 }
402 }
403
404 fn to_bool(&self, value: &DataValue) -> Result<bool> {
406 match value {
407 DataValue::Boolean(b) => Ok(*b),
408 DataValue::Integer(i) => Ok(*i != 0),
409 DataValue::Float(f) => Ok(*f != 0.0),
410 DataValue::Null => Ok(false),
411 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
412 }
413 }
414
415 fn evaluate_function_with_distinct(
417 &mut self,
418 name: &str,
419 args: &[SqlExpression],
420 distinct: bool,
421 row_index: usize,
422 ) -> Result<DataValue> {
423 if distinct {
425 let name_upper = name.to_uppercase();
426
427 if self.aggregate_registry.is_aggregate(&name_upper)
429 || self.new_aggregate_registry.contains(&name_upper)
430 {
431 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
432 } else {
433 return Err(anyhow!(
434 "DISTINCT can only be used with aggregate functions"
435 ));
436 }
437 }
438
439 self.evaluate_function(name, args, row_index)
441 }
442
443 fn evaluate_aggregate_with_distinct(
444 &mut self,
445 name: &str,
446 args: &[SqlExpression],
447 _row_index: usize,
448 ) -> Result<DataValue> {
449 let name_upper = name.to_uppercase();
450
451 if self.new_aggregate_registry.get(&name_upper).is_some() {
453 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
454 visible.clone()
455 } else {
456 (0..self.table.rows.len()).collect()
457 };
458
459 let mut vals = Vec::new();
461 for &row_idx in &rows_to_process {
462 if !args.is_empty() {
463 let value = self.evaluate(&args[0], row_idx)?;
464 vals.push(value);
465 }
466 }
467
468 let mut seen = HashSet::new();
470 let unique_values: Vec<_> = vals
471 .into_iter()
472 .filter(|v| {
473 let key = format!("{:?}", v);
474 seen.insert(key)
475 })
476 .collect();
477
478 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
480 let mut state = agg_func.create_state();
481
482 for value in &unique_values {
484 state.accumulate(value)?;
485 }
486
487 return Ok(state.finalize());
488 }
489
490 if self.aggregate_registry.get(&name_upper).is_some() {
492 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
494 visible.clone()
495 } else {
496 (0..self.table.rows.len()).collect()
497 };
498
499 if name_upper == "STRING_AGG" && args.len() >= 2 {
501 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
503 if args.len() >= 2 {
505 let separator = self.evaluate(&args[1], 0)?; match separator {
507 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
508 DataValue::InternedString(s) => {
509 crate::sql::aggregates::StringAggState::new(&s)
510 }
511 _ => crate::sql::aggregates::StringAggState::new(","), }
513 } else {
514 crate::sql::aggregates::StringAggState::new(",")
515 },
516 );
517
518 let mut seen_values = HashSet::new();
521
522 for &row_idx in &rows_to_process {
523 let value = self.evaluate(&args[0], row_idx)?;
524
525 if !seen_values.insert(value.clone()) {
527 continue; }
529
530 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
532 agg_func.accumulate(&mut state, &value)?;
533 }
534
535 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
537 return Ok(agg_func.finalize(state));
538 }
539
540 let mut vals = Vec::new();
543 for &row_idx in &rows_to_process {
544 if !args.is_empty() {
545 let value = self.evaluate(&args[0], row_idx)?;
546 vals.push(value);
547 }
548 }
549
550 let mut seen = HashSet::new();
552 let mut unique_values = Vec::new();
553 for value in vals {
554 if seen.insert(value.clone()) {
555 unique_values.push(value);
556 }
557 }
558
559 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
561 let mut state = agg_func.init();
562
563 for value in &unique_values {
565 agg_func.accumulate(&mut state, value)?;
566 }
567
568 return Ok(agg_func.finalize(state));
569 }
570
571 Err(anyhow!("Unknown aggregate function: {}", name))
572 }
573
574 fn evaluate_function(
575 &mut self,
576 name: &str,
577 args: &[SqlExpression],
578 row_index: usize,
579 ) -> Result<DataValue> {
580 let name_upper = name.to_uppercase();
582
583 if self.new_aggregate_registry.get(&name_upper).is_some() {
585 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
587 visible.clone()
588 } else {
589 (0..self.table.rows.len()).collect()
590 };
591
592 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
594 let mut state = agg_func.create_state();
595
596 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
598 if args.is_empty()
599 || (args.len() == 1
600 && matches!(&args[0], SqlExpression::Column(col) if col == "*"))
601 || (args.len() == 1
602 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
603 {
604 for _ in &rows_to_process {
606 state.accumulate(&DataValue::Integer(1))?;
607 }
608 } else {
609 for &row_idx in &rows_to_process {
611 let value = self.evaluate(&args[0], row_idx)?;
612 state.accumulate(&value)?;
613 }
614 }
615 } else {
616 if !args.is_empty() {
618 for &row_idx in &rows_to_process {
619 let value = self.evaluate(&args[0], row_idx)?;
620 state.accumulate(&value)?;
621 }
622 }
623 }
624
625 return Ok(state.finalize());
626 }
627
628 if self.aggregate_registry.get(&name_upper).is_some() {
630 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
632 visible.clone()
633 } else {
634 (0..self.table.rows.len()).collect()
635 };
636
637 if name_upper == "STRING_AGG" && args.len() >= 2 {
639 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
641 if args.len() >= 2 {
643 let separator = self.evaluate(&args[1], 0)?; match separator {
645 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
646 DataValue::InternedString(s) => {
647 crate::sql::aggregates::StringAggState::new(&s)
648 }
649 _ => crate::sql::aggregates::StringAggState::new(","), }
651 } else {
652 crate::sql::aggregates::StringAggState::new(",")
653 },
654 );
655
656 for &row_idx in &rows_to_process {
658 let value = self.evaluate(&args[0], row_idx)?;
659 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
661 agg_func.accumulate(&mut state, &value)?;
662 }
663
664 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
666 return Ok(agg_func.finalize(state));
667 }
668
669 let values = if !args.is_empty()
671 && !(args.len() == 1 && matches!(&args[0], SqlExpression::Column(c) if c == "*"))
672 {
673 let mut vals = Vec::new();
675 for &row_idx in &rows_to_process {
676 let value = self.evaluate(&args[0], row_idx)?;
677 vals.push(value);
678 }
679 Some(vals)
680 } else {
681 None
682 };
683
684 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
686 let mut state = agg_func.init();
687
688 if let Some(values) = values {
689 for value in &values {
691 agg_func.accumulate(&mut state, value)?;
692 }
693 } else {
694 for _ in &rows_to_process {
696 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
697 }
698 }
699
700 return Ok(agg_func.finalize(state));
701 }
702
703 if self.function_registry.get(name).is_some() {
705 let mut evaluated_args = Vec::new();
707 for arg in args {
708 evaluated_args.push(self.evaluate(arg, row_index)?);
709 }
710
711 let func = self.function_registry.get(name).unwrap();
713 return func.evaluate(&evaluated_args);
714 }
715
716 Err(anyhow!("Unknown function: {}", name))
718 }
719
720 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
722 let key = format!("{:?}", spec);
724
725 if let Some(context) = self.window_contexts.get(&key) {
726 return Ok(Arc::clone(context));
727 }
728
729 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
731 let view = DataView::new(Arc::new(self.table.clone()));
733 view
736 } else {
737 DataView::new(Arc::new(self.table.clone()))
738 };
739
740 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
742
743 let context = Arc::new(context);
744 self.window_contexts.insert(key, Arc::clone(&context));
745 Ok(context)
746 }
747
748 fn evaluate_window_function(
750 &mut self,
751 name: &str,
752 args: &[SqlExpression],
753 spec: &WindowSpec,
754 row_index: usize,
755 ) -> Result<DataValue> {
756 let name_upper = name.to_uppercase();
757
758 debug!("Looking for window function {} in registry", name_upper);
760 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
761 debug!("Found window function {} in registry", name_upper);
762
763 let window_fn = window_fn_arc.as_ref();
765
766 window_fn.validate_args(args)?;
768
769 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
771
772 let context = self.get_or_create_window_context(&transformed_spec)?;
774
775 struct EvaluatorAdapter<'a, 'b> {
777 evaluator: &'a mut ArithmeticEvaluator<'b>,
778 row_index: usize,
779 }
780
781 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
782 fn evaluate(
783 &mut self,
784 expr: &SqlExpression,
785 _row_index: usize,
786 ) -> Result<DataValue> {
787 self.evaluator.evaluate(expr, self.row_index)
788 }
789 }
790
791 let mut adapter = EvaluatorAdapter {
792 evaluator: self,
793 row_index,
794 };
795
796 return window_fn.compute(&context, row_index, args, &mut adapter);
798 }
799
800 let context = self.get_or_create_window_context(spec)?;
802
803 match name_upper.as_str() {
804 "LAG" => {
805 if args.is_empty() {
807 return Err(anyhow!("LAG requires at least 1 argument"));
808 }
809
810 let column = match &args[0] {
812 SqlExpression::Column(col) => col.clone(),
813 _ => return Err(anyhow!("LAG first argument must be a column")),
814 };
815
816 let offset = if args.len() > 1 {
818 match self.evaluate(&args[1], row_index)? {
819 DataValue::Integer(i) => i as i32,
820 _ => return Err(anyhow!("LAG offset must be an integer")),
821 }
822 } else {
823 1
824 };
825
826 Ok(context
828 .get_offset_value(row_index, -offset, &column)
829 .unwrap_or(DataValue::Null))
830 }
831 "LEAD" => {
832 if args.is_empty() {
834 return Err(anyhow!("LEAD requires at least 1 argument"));
835 }
836
837 let column = match &args[0] {
839 SqlExpression::Column(col) => col.clone(),
840 _ => return Err(anyhow!("LEAD first argument must be a column")),
841 };
842
843 let offset = if args.len() > 1 {
845 match self.evaluate(&args[1], row_index)? {
846 DataValue::Integer(i) => i as i32,
847 _ => return Err(anyhow!("LEAD offset must be an integer")),
848 }
849 } else {
850 1
851 };
852
853 Ok(context
855 .get_offset_value(row_index, offset, &column)
856 .unwrap_or(DataValue::Null))
857 }
858 "ROW_NUMBER" => {
859 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
861 }
862 "FIRST_VALUE" => {
863 if args.is_empty() {
865 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
866 }
867
868 let column = match &args[0] {
869 SqlExpression::Column(col) => col.clone(),
870 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
871 };
872
873 if context.has_frame() {
875 Ok(context
876 .get_frame_first_value(row_index, &column)
877 .unwrap_or(DataValue::Null))
878 } else {
879 Ok(context
880 .get_first_value(row_index, &column)
881 .unwrap_or(DataValue::Null))
882 }
883 }
884 "LAST_VALUE" => {
885 if args.is_empty() {
887 return Err(anyhow!("LAST_VALUE requires 1 argument"));
888 }
889
890 let column = match &args[0] {
891 SqlExpression::Column(col) => col.clone(),
892 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
893 };
894
895 if context.has_frame() {
897 Ok(context
898 .get_frame_last_value(row_index, &column)
899 .unwrap_or(DataValue::Null))
900 } else {
901 Ok(context
902 .get_last_value(row_index, &column)
903 .unwrap_or(DataValue::Null))
904 }
905 }
906 "SUM" => {
907 if args.is_empty() {
909 return Err(anyhow!("SUM requires 1 argument"));
910 }
911
912 let column = match &args[0] {
913 SqlExpression::Column(col) => col.clone(),
914 _ => return Err(anyhow!("SUM argument must be a column")),
915 };
916
917 if context.has_frame() {
919 Ok(context
920 .get_frame_sum(row_index, &column)
921 .unwrap_or(DataValue::Null))
922 } else {
923 Ok(context
924 .get_partition_sum(row_index, &column)
925 .unwrap_or(DataValue::Null))
926 }
927 }
928 "AVG" => {
929 if args.is_empty() {
931 return Err(anyhow!("AVG requires 1 argument"));
932 }
933
934 let column = match &args[0] {
935 SqlExpression::Column(col) => col.clone(),
936 _ => return Err(anyhow!("AVG argument must be a column")),
937 };
938
939 Ok(context
940 .get_frame_avg(row_index, &column)
941 .unwrap_or(DataValue::Null))
942 }
943 "STDDEV" | "STDEV" => {
944 if args.is_empty() {
946 return Err(anyhow!("STDDEV requires 1 argument"));
947 }
948
949 let column = match &args[0] {
950 SqlExpression::Column(col) => col.clone(),
951 _ => return Err(anyhow!("STDDEV argument must be a column")),
952 };
953
954 Ok(context
955 .get_frame_stddev(row_index, &column)
956 .unwrap_or(DataValue::Null))
957 }
958 "VARIANCE" | "VAR" => {
959 if args.is_empty() {
961 return Err(anyhow!("VARIANCE requires 1 argument"));
962 }
963
964 let column = match &args[0] {
965 SqlExpression::Column(col) => col.clone(),
966 _ => return Err(anyhow!("VARIANCE argument must be a column")),
967 };
968
969 Ok(context
970 .get_frame_variance(row_index, &column)
971 .unwrap_or(DataValue::Null))
972 }
973 "MIN" => {
974 if args.is_empty() {
976 return Err(anyhow!("MIN requires 1 argument"));
977 }
978
979 let column = match &args[0] {
980 SqlExpression::Column(col) => col.clone(),
981 _ => return Err(anyhow!("MIN argument must be a column")),
982 };
983
984 let frame_rows = context.get_frame_rows(row_index);
985 if frame_rows.is_empty() {
986 return Ok(DataValue::Null);
987 }
988
989 let source_table = context.source();
990 let col_idx = source_table
991 .get_column_index(&column)
992 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
993
994 let mut min_value: Option<DataValue> = None;
995 for &row_idx in &frame_rows {
996 if let Some(value) = source_table.get_value(row_idx, col_idx) {
997 if !matches!(value, DataValue::Null) {
998 match &min_value {
999 None => min_value = Some(value.clone()),
1000 Some(current_min) => {
1001 if value < current_min {
1002 min_value = Some(value.clone());
1003 }
1004 }
1005 }
1006 }
1007 }
1008 }
1009
1010 Ok(min_value.unwrap_or(DataValue::Null))
1011 }
1012 "MAX" => {
1013 if args.is_empty() {
1015 return Err(anyhow!("MAX requires 1 argument"));
1016 }
1017
1018 let column = match &args[0] {
1019 SqlExpression::Column(col) => col.clone(),
1020 _ => return Err(anyhow!("MAX argument must be a column")),
1021 };
1022
1023 let frame_rows = context.get_frame_rows(row_index);
1024 if frame_rows.is_empty() {
1025 return Ok(DataValue::Null);
1026 }
1027
1028 let source_table = context.source();
1029 let col_idx = source_table
1030 .get_column_index(&column)
1031 .ok_or_else(|| anyhow!("Column '{}' not found", column))?;
1032
1033 let mut max_value: Option<DataValue> = None;
1034 for &row_idx in &frame_rows {
1035 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1036 if !matches!(value, DataValue::Null) {
1037 match &max_value {
1038 None => max_value = Some(value.clone()),
1039 Some(current_max) => {
1040 if value > current_max {
1041 max_value = Some(value.clone());
1042 }
1043 }
1044 }
1045 }
1046 }
1047 }
1048
1049 Ok(max_value.unwrap_or(DataValue::Null))
1050 }
1051 "COUNT" => {
1052 if args.is_empty() {
1056 if context.has_frame() {
1058 Ok(context
1059 .get_frame_count(row_index, None)
1060 .unwrap_or(DataValue::Null))
1061 } else {
1062 Ok(context
1063 .get_partition_count(row_index, None)
1064 .unwrap_or(DataValue::Null))
1065 }
1066 } else {
1067 let column = match &args[0] {
1069 SqlExpression::Column(col) => {
1070 if col == "*" {
1071 if context.has_frame() {
1073 return Ok(context
1074 .get_frame_count(row_index, None)
1075 .unwrap_or(DataValue::Null));
1076 } else {
1077 return Ok(context
1078 .get_partition_count(row_index, None)
1079 .unwrap_or(DataValue::Null));
1080 }
1081 }
1082 col.clone()
1083 }
1084 SqlExpression::StringLiteral(s) if s == "*" => {
1085 if context.has_frame() {
1087 return Ok(context
1088 .get_frame_count(row_index, None)
1089 .unwrap_or(DataValue::Null));
1090 } else {
1091 return Ok(context
1092 .get_partition_count(row_index, None)
1093 .unwrap_or(DataValue::Null));
1094 }
1095 }
1096 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1097 };
1098
1099 if context.has_frame() {
1101 Ok(context
1102 .get_frame_count(row_index, Some(&column))
1103 .unwrap_or(DataValue::Null))
1104 } else {
1105 Ok(context
1106 .get_partition_count(row_index, Some(&column))
1107 .unwrap_or(DataValue::Null))
1108 }
1109 }
1110 }
1111 _ => Err(anyhow!("Unknown window function: {}", name)),
1112 }
1113 }
1114
1115 fn evaluate_method_call(
1117 &mut self,
1118 object: &str,
1119 method: &str,
1120 args: &[SqlExpression],
1121 row_index: usize,
1122 ) -> Result<DataValue> {
1123 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1125 let suggestion = self.find_similar_column(object);
1126 match suggestion {
1127 Some(similar) => {
1128 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1129 }
1130 None => anyhow!("Column '{}' not found", object),
1131 }
1132 })?;
1133
1134 let cell_value = self.table.get_value(row_index, col_index).cloned();
1135
1136 self.evaluate_method_on_value(
1137 &cell_value.unwrap_or(DataValue::Null),
1138 method,
1139 args,
1140 row_index,
1141 )
1142 }
1143
1144 fn evaluate_method_on_value(
1146 &mut self,
1147 value: &DataValue,
1148 method: &str,
1149 args: &[SqlExpression],
1150 row_index: usize,
1151 ) -> Result<DataValue> {
1152 let function_name = match method.to_lowercase().as_str() {
1157 "trim" => "TRIM",
1158 "trimstart" | "trimbegin" => "TRIMSTART",
1159 "trimend" => "TRIMEND",
1160 "length" | "len" => "LENGTH",
1161 "contains" => "CONTAINS",
1162 "startswith" => "STARTSWITH",
1163 "endswith" => "ENDSWITH",
1164 "indexof" => "INDEXOF",
1165 _ => method, };
1167
1168 if self.function_registry.get(function_name).is_some() {
1170 debug!(
1171 "Proxying method '{}' through function registry as '{}'",
1172 method, function_name
1173 );
1174
1175 let mut func_args = vec![value.clone()];
1177
1178 for arg in args {
1180 func_args.push(self.evaluate(arg, row_index)?);
1181 }
1182
1183 let func = self.function_registry.get(function_name).unwrap();
1185 return func.evaluate(&func_args);
1186 }
1187
1188 Err(anyhow!(
1191 "Method '{}' not found. It should be registered in the function registry.",
1192 method
1193 ))
1194 }
1195
1196 fn evaluate_case_expression(
1198 &mut self,
1199 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1200 else_branch: &Option<Box<SqlExpression>>,
1201 row_index: usize,
1202 ) -> Result<DataValue> {
1203 debug!(
1204 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1205 row_index
1206 );
1207
1208 for branch in when_branches {
1210 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1212
1213 if condition_result {
1214 debug!("CASE: WHEN condition matched, evaluating result expression");
1215 return self.evaluate(&branch.result, row_index);
1216 }
1217 }
1218
1219 if let Some(else_expr) = else_branch {
1221 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1222 self.evaluate(else_expr, row_index)
1223 } else {
1224 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1225 Ok(DataValue::Null)
1226 }
1227 }
1228
1229 fn evaluate_simple_case_expression(
1231 &mut self,
1232 expr: &Box<SqlExpression>,
1233 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1234 else_branch: &Option<Box<SqlExpression>>,
1235 row_index: usize,
1236 ) -> Result<DataValue> {
1237 debug!(
1238 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1239 row_index
1240 );
1241
1242 let case_value = self.evaluate(expr, row_index)?;
1244 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1245
1246 for branch in when_branches {
1248 let when_value = self.evaluate(&branch.value, row_index)?;
1250
1251 if self.values_equal(&case_value, &when_value)? {
1253 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1254 return self.evaluate(&branch.result, row_index);
1255 }
1256 }
1257
1258 if let Some(else_expr) = else_branch {
1260 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1261 self.evaluate(else_expr, row_index)
1262 } else {
1263 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1264 Ok(DataValue::Null)
1265 }
1266 }
1267
1268 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1270 match (left, right) {
1271 (DataValue::Null, DataValue::Null) => Ok(true),
1272 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1273 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1274 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1275 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1276 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1277 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1278 (DataValue::Integer(a), DataValue::Float(b)) => {
1280 Ok((*a as f64 - b).abs() < f64::EPSILON)
1281 }
1282 (DataValue::Float(a), DataValue::Integer(b)) => {
1283 Ok((a - *b as f64).abs() < f64::EPSILON)
1284 }
1285 _ => Ok(false),
1286 }
1287 }
1288
1289 fn evaluate_condition_as_bool(
1291 &mut self,
1292 expr: &SqlExpression,
1293 row_index: usize,
1294 ) -> Result<bool> {
1295 let value = self.evaluate(expr, row_index)?;
1296
1297 match value {
1298 DataValue::Boolean(b) => Ok(b),
1299 DataValue::Integer(i) => Ok(i != 0),
1300 DataValue::Float(f) => Ok(f != 0.0),
1301 DataValue::Null => Ok(false),
1302 DataValue::String(s) => Ok(!s.is_empty()),
1303 DataValue::InternedString(s) => Ok(!s.is_empty()),
1304 _ => Ok(true), }
1306 }
1307}
1308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column("a".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Deliberately not 3.14: a float literal approximating PI trips the
        // deny-by-default `clippy::approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Integer + Integer stays an integer.
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        // Integer + Float yields a float.
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // Exact integer division stays an integer.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        // Inexact integer division yields a float.
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column("a".to_string())),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column("b".to_string())),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal_numeric_and_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer and float are compared numerically, in either order.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        assert!(evaluator
            .values_equal(&DataValue::Float(2.0), &DataValue::Integer(2))
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Integer(1), &DataValue::Integer(2))
            .unwrap());

        // NULL only equals NULL.
        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());

        // Unrelated variant pairs are never equal.
        assert!(!evaluator
            .values_equal(&DataValue::Boolean(true), &DataValue::Integer(1))
            .unwrap());
    }

    #[test]
    fn test_evaluate_condition_as_bool_numeric_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        // A zero integer is falsy.
        let zero = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());

        // Column `a` holds 10, which is truthy.
        let column = SqlExpression::Column("a".to_string());
        assert!(evaluator.evaluate_condition_as_bool(&column, 0).unwrap());
    }
}