1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17pub struct ArithmeticEvaluator<'a> {
20 table: &'a DataTable,
21 date_notation: String,
22 function_registry: Arc<FunctionRegistry>,
23 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
26 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
30
31impl<'a> ArithmeticEvaluator<'a> {
32 #[must_use]
33 pub fn new(table: &'a DataTable) -> Self {
34 Self {
35 table,
36 date_notation: get_date_notation(),
37 function_registry: Arc::new(FunctionRegistry::new()),
38 aggregate_registry: Arc::new(AggregateRegistry::new()),
39 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41 visible_rows: None,
42 window_contexts: HashMap::new(),
43 table_aliases: HashMap::new(),
44 }
45 }
46
47 #[must_use]
48 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49 Self {
50 table,
51 date_notation,
52 function_registry: Arc::new(FunctionRegistry::new()),
53 aggregate_registry: Arc::new(AggregateRegistry::new()),
54 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56 visible_rows: None,
57 window_contexts: HashMap::new(),
58 table_aliases: HashMap::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65 self.visible_rows = Some(rows);
66 self
67 }
68
69 #[must_use]
71 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72 self.table_aliases = aliases;
73 self
74 }
75
76 #[must_use]
77 pub fn with_date_notation_and_registry(
78 table: &'a DataTable,
79 date_notation: String,
80 function_registry: Arc<FunctionRegistry>,
81 ) -> Self {
82 Self {
83 table,
84 date_notation,
85 function_registry,
86 aggregate_registry: Arc::new(AggregateRegistry::new()),
87 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89 visible_rows: None,
90 window_contexts: HashMap::new(),
91 table_aliases: HashMap::new(),
92 }
93 }
94
95 fn find_similar_column(&self, name: &str) -> Option<String> {
97 let columns = self.table.column_names();
98 let mut best_match: Option<(String, usize)> = None;
99
100 for col in columns {
101 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102 let max_distance = if name.len() > 10 { 3 } else { 2 };
105 if distance <= max_distance {
106 match &best_match {
107 None => best_match = Some((col, distance)),
108 Some((_, best_dist)) if distance < *best_dist => {
109 best_match = Some((col, distance));
110 }
111 _ => {}
112 }
113 }
114 }
115
116 best_match.map(|(name, _)| name)
117 }
118
119 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123 }
124
125 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127 debug!(
128 "ArithmeticEvaluator: evaluating {:?} for row {}",
129 expr, row_index
130 );
131
132 match expr {
133 SqlExpression::Column(column_ref) => self.evaluate_column(&column_ref.name, row_index),
134 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137 SqlExpression::Null => Ok(DataValue::Null),
138 SqlExpression::BinaryOp { left, op, right } => {
139 self.evaluate_binary_op(left, op, right, row_index)
140 }
141 SqlExpression::FunctionCall {
142 name,
143 args,
144 distinct,
145 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146 SqlExpression::WindowFunction {
147 name,
148 args,
149 window_spec,
150 } => self.evaluate_window_function(name, args, window_spec, row_index),
151 SqlExpression::MethodCall {
152 object,
153 method,
154 args,
155 } => self.evaluate_method_call(object, method, args, row_index),
156 SqlExpression::ChainedMethodCall { base, method, args } => {
157 let base_value = self.evaluate(base, row_index)?;
159 self.evaluate_method_on_value(&base_value, method, args, row_index)
160 }
161 SqlExpression::CaseExpression {
162 when_branches,
163 else_branch,
164 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165 SqlExpression::SimpleCaseExpression {
166 expr,
167 when_branches,
168 else_branch,
169 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170 _ => Err(anyhow!(
171 "Unsupported expression type for arithmetic evaluation: {:?}",
172 expr
173 )),
174 }
175 }
176
177 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
179 let resolved_column = if column_name.contains('.') {
181 if let Some(dot_pos) = column_name.rfind('.') {
183 let _table_or_alias = &column_name[..dot_pos];
184 let col_name = &column_name[dot_pos + 1..];
185
186 debug!(
189 "Resolving qualified column: {} -> {}",
190 column_name, col_name
191 );
192 col_name.to_string()
193 } else {
194 column_name.to_string()
195 }
196 } else {
197 column_name.to_string()
198 };
199
200 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
201 idx
202 } else if resolved_column != column_name {
203 if let Some(idx) = self.table.get_column_index(column_name) {
205 idx
206 } else {
207 let suggestion = self.find_similar_column(&resolved_column);
208 return Err(match suggestion {
209 Some(similar) => anyhow!(
210 "Column '{}' not found. Did you mean '{}'?",
211 column_name,
212 similar
213 ),
214 None => anyhow!("Column '{}' not found", column_name),
215 });
216 }
217 } else {
218 let suggestion = self.find_similar_column(&resolved_column);
219 return Err(match suggestion {
220 Some(similar) => anyhow!(
221 "Column '{}' not found. Did you mean '{}'?",
222 column_name,
223 similar
224 ),
225 None => anyhow!("Column '{}' not found", column_name),
226 });
227 };
228
229 if row_index >= self.table.row_count() {
230 return Err(anyhow!("Row index {} out of bounds", row_index));
231 }
232
233 let row = self
234 .table
235 .get_row(row_index)
236 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
237
238 let value = row
239 .get(col_index)
240 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
241
242 Ok(value.clone())
243 }
244
245 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
247 if let Ok(int_val) = number_str.parse::<i64>() {
249 return Ok(DataValue::Integer(int_val));
250 }
251
252 if let Ok(float_val) = number_str.parse::<f64>() {
254 return Ok(DataValue::Float(float_val));
255 }
256
257 Err(anyhow!("Invalid number literal: {}", number_str))
258 }
259
260 fn evaluate_binary_op(
262 &mut self,
263 left: &SqlExpression,
264 op: &str,
265 right: &SqlExpression,
266 row_index: usize,
267 ) -> Result<DataValue> {
268 let left_val = self.evaluate(left, row_index)?;
269 let right_val = self.evaluate(right, row_index)?;
270
271 debug!(
272 "ArithmeticEvaluator: {} {} {}",
273 self.format_value(&left_val),
274 op,
275 self.format_value(&right_val)
276 );
277
278 match op {
279 "+" => self.add_values(&left_val, &right_val),
280 "-" => self.subtract_values(&left_val, &right_val),
281 "*" => self.multiply_values(&left_val, &right_val),
282 "/" => self.divide_values(&left_val, &right_val),
283 "%" => {
284 let args = vec![left.clone(), right.clone()];
286 self.evaluate_function("MOD", &args, row_index)
287 }
288 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
291 let result = compare_with_op(&left_val, &right_val, op, false);
292 Ok(DataValue::Boolean(result))
293 }
294 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
296 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
297 "AND" => {
299 let left_bool = self.to_bool(&left_val)?;
300 let right_bool = self.to_bool(&right_val)?;
301 Ok(DataValue::Boolean(left_bool && right_bool))
302 }
303 "OR" => {
304 let left_bool = self.to_bool(&left_val)?;
305 let right_bool = self.to_bool(&right_val)?;
306 Ok(DataValue::Boolean(left_bool || right_bool))
307 }
308 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
309 }
310 }
311
312 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
314 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
316 return Ok(DataValue::Null);
317 }
318
319 match (left, right) {
320 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
321 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
322 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
323 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
324 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
325 }
326 }
327
328 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
330 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
332 return Ok(DataValue::Null);
333 }
334
335 match (left, right) {
336 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
337 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
338 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
339 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
340 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
341 }
342 }
343
344 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
346 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
348 return Ok(DataValue::Null);
349 }
350
351 match (left, right) {
352 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
353 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
354 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
355 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
356 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
357 }
358 }
359
360 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
362 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
364 return Ok(DataValue::Null);
365 }
366
367 let is_zero = match right {
369 DataValue::Integer(0) => true,
370 DataValue::Float(f) if *f == 0.0 => true, _ => false,
372 };
373
374 if is_zero {
375 return Err(anyhow!("Division by zero"));
376 }
377
378 match (left, right) {
379 (DataValue::Integer(a), DataValue::Integer(b)) => {
380 if a % b == 0 {
382 Ok(DataValue::Integer(a / b))
383 } else {
384 Ok(DataValue::Float(*a as f64 / *b as f64))
385 }
386 }
387 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
388 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
389 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
390 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
391 }
392 }
393
394 fn format_value(&self, value: &DataValue) -> String {
396 match value {
397 DataValue::Integer(i) => i.to_string(),
398 DataValue::Float(f) => f.to_string(),
399 DataValue::String(s) => format!("'{s}'"),
400 _ => format!("{value:?}"),
401 }
402 }
403
404 fn to_bool(&self, value: &DataValue) -> Result<bool> {
406 match value {
407 DataValue::Boolean(b) => Ok(*b),
408 DataValue::Integer(i) => Ok(*i != 0),
409 DataValue::Float(f) => Ok(*f != 0.0),
410 DataValue::Null => Ok(false),
411 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
412 }
413 }
414
415 fn evaluate_function_with_distinct(
417 &mut self,
418 name: &str,
419 args: &[SqlExpression],
420 distinct: bool,
421 row_index: usize,
422 ) -> Result<DataValue> {
423 if distinct {
425 let name_upper = name.to_uppercase();
426
427 if self.aggregate_registry.is_aggregate(&name_upper)
429 || self.new_aggregate_registry.contains(&name_upper)
430 {
431 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
432 } else {
433 return Err(anyhow!(
434 "DISTINCT can only be used with aggregate functions"
435 ));
436 }
437 }
438
439 self.evaluate_function(name, args, row_index)
441 }
442
443 fn evaluate_aggregate_with_distinct(
444 &mut self,
445 name: &str,
446 args: &[SqlExpression],
447 _row_index: usize,
448 ) -> Result<DataValue> {
449 let name_upper = name.to_uppercase();
450
451 if self.new_aggregate_registry.get(&name_upper).is_some() {
453 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
454 visible.clone()
455 } else {
456 (0..self.table.rows.len()).collect()
457 };
458
459 let mut vals = Vec::new();
461 for &row_idx in &rows_to_process {
462 if !args.is_empty() {
463 let value = self.evaluate(&args[0], row_idx)?;
464 vals.push(value);
465 }
466 }
467
468 let mut seen = HashSet::new();
470 let unique_values: Vec<_> = vals
471 .into_iter()
472 .filter(|v| {
473 let key = format!("{:?}", v);
474 seen.insert(key)
475 })
476 .collect();
477
478 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
480 let mut state = agg_func.create_state();
481
482 for value in &unique_values {
484 state.accumulate(value)?;
485 }
486
487 return Ok(state.finalize());
488 }
489
490 if self.aggregate_registry.get(&name_upper).is_some() {
492 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
494 visible.clone()
495 } else {
496 (0..self.table.rows.len()).collect()
497 };
498
499 if name_upper == "STRING_AGG" && args.len() >= 2 {
501 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
503 if args.len() >= 2 {
505 let separator = self.evaluate(&args[1], 0)?; match separator {
507 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
508 DataValue::InternedString(s) => {
509 crate::sql::aggregates::StringAggState::new(&s)
510 }
511 _ => crate::sql::aggregates::StringAggState::new(","), }
513 } else {
514 crate::sql::aggregates::StringAggState::new(",")
515 },
516 );
517
518 let mut seen_values = HashSet::new();
521
522 for &row_idx in &rows_to_process {
523 let value = self.evaluate(&args[0], row_idx)?;
524
525 if !seen_values.insert(value.clone()) {
527 continue; }
529
530 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
532 agg_func.accumulate(&mut state, &value)?;
533 }
534
535 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
537 return Ok(agg_func.finalize(state));
538 }
539
540 let mut vals = Vec::new();
543 for &row_idx in &rows_to_process {
544 if !args.is_empty() {
545 let value = self.evaluate(&args[0], row_idx)?;
546 vals.push(value);
547 }
548 }
549
550 let mut seen = HashSet::new();
552 let mut unique_values = Vec::new();
553 for value in vals {
554 if seen.insert(value.clone()) {
555 unique_values.push(value);
556 }
557 }
558
559 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
561 let mut state = agg_func.init();
562
563 for value in &unique_values {
565 agg_func.accumulate(&mut state, value)?;
566 }
567
568 return Ok(agg_func.finalize(state));
569 }
570
571 Err(anyhow!("Unknown aggregate function: {}", name))
572 }
573
574 fn evaluate_function(
575 &mut self,
576 name: &str,
577 args: &[SqlExpression],
578 row_index: usize,
579 ) -> Result<DataValue> {
580 let name_upper = name.to_uppercase();
582
583 if self.new_aggregate_registry.get(&name_upper).is_some() {
585 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
587 visible.clone()
588 } else {
589 (0..self.table.rows.len()).collect()
590 };
591
592 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
594 let mut state = agg_func.create_state();
595
596 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
598 if args.is_empty()
599 || (args.len() == 1
600 && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
601 || (args.len() == 1
602 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
603 {
604 for _ in &rows_to_process {
606 state.accumulate(&DataValue::Integer(1))?;
607 }
608 } else {
609 for &row_idx in &rows_to_process {
611 let value = self.evaluate(&args[0], row_idx)?;
612 state.accumulate(&value)?;
613 }
614 }
615 } else {
616 if !args.is_empty() {
618 for &row_idx in &rows_to_process {
619 let value = self.evaluate(&args[0], row_idx)?;
620 state.accumulate(&value)?;
621 }
622 }
623 }
624
625 return Ok(state.finalize());
626 }
627
628 if self.aggregate_registry.get(&name_upper).is_some() {
630 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
632 visible.clone()
633 } else {
634 (0..self.table.rows.len()).collect()
635 };
636
637 if name_upper == "STRING_AGG" && args.len() >= 2 {
639 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
641 if args.len() >= 2 {
643 let separator = self.evaluate(&args[1], 0)?; match separator {
645 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
646 DataValue::InternedString(s) => {
647 crate::sql::aggregates::StringAggState::new(&s)
648 }
649 _ => crate::sql::aggregates::StringAggState::new(","), }
651 } else {
652 crate::sql::aggregates::StringAggState::new(",")
653 },
654 );
655
656 for &row_idx in &rows_to_process {
658 let value = self.evaluate(&args[0], row_idx)?;
659 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
661 agg_func.accumulate(&mut state, &value)?;
662 }
663
664 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
666 return Ok(agg_func.finalize(state));
667 }
668
669 let values = if !args.is_empty()
671 && !(args.len() == 1
672 && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
673 {
674 let mut vals = Vec::new();
676 for &row_idx in &rows_to_process {
677 let value = self.evaluate(&args[0], row_idx)?;
678 vals.push(value);
679 }
680 Some(vals)
681 } else {
682 None
683 };
684
685 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
687 let mut state = agg_func.init();
688
689 if let Some(values) = values {
690 for value in &values {
692 agg_func.accumulate(&mut state, value)?;
693 }
694 } else {
695 for _ in &rows_to_process {
697 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
698 }
699 }
700
701 return Ok(agg_func.finalize(state));
702 }
703
704 if self.function_registry.get(name).is_some() {
706 let mut evaluated_args = Vec::new();
708 for arg in args {
709 evaluated_args.push(self.evaluate(arg, row_index)?);
710 }
711
712 let func = self.function_registry.get(name).unwrap();
714 return func.evaluate(&evaluated_args);
715 }
716
717 Err(anyhow!("Unknown function: {}", name))
719 }
720
721 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
723 let key = format!("{:?}", spec);
725
726 if let Some(context) = self.window_contexts.get(&key) {
727 return Ok(Arc::clone(context));
728 }
729
730 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
732 let view = DataView::new(Arc::new(self.table.clone()));
734 view
737 } else {
738 DataView::new(Arc::new(self.table.clone()))
739 };
740
741 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
743
744 let context = Arc::new(context);
745 self.window_contexts.insert(key, Arc::clone(&context));
746 Ok(context)
747 }
748
749 fn evaluate_window_function(
751 &mut self,
752 name: &str,
753 args: &[SqlExpression],
754 spec: &WindowSpec,
755 row_index: usize,
756 ) -> Result<DataValue> {
757 let name_upper = name.to_uppercase();
758
759 debug!("Looking for window function {} in registry", name_upper);
761 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
762 debug!("Found window function {} in registry", name_upper);
763
764 let window_fn = window_fn_arc.as_ref();
766
767 window_fn.validate_args(args)?;
769
770 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
772
773 let context = self.get_or_create_window_context(&transformed_spec)?;
775
776 struct EvaluatorAdapter<'a, 'b> {
778 evaluator: &'a mut ArithmeticEvaluator<'b>,
779 row_index: usize,
780 }
781
782 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
783 fn evaluate(
784 &mut self,
785 expr: &SqlExpression,
786 _row_index: usize,
787 ) -> Result<DataValue> {
788 self.evaluator.evaluate(expr, self.row_index)
789 }
790 }
791
792 let mut adapter = EvaluatorAdapter {
793 evaluator: self,
794 row_index,
795 };
796
797 return window_fn.compute(&context, row_index, args, &mut adapter);
799 }
800
801 let context = self.get_or_create_window_context(spec)?;
803
804 match name_upper.as_str() {
805 "LAG" => {
806 if args.is_empty() {
808 return Err(anyhow!("LAG requires at least 1 argument"));
809 }
810
811 let column = match &args[0] {
813 SqlExpression::Column(col) => col.clone(),
814 _ => return Err(anyhow!("LAG first argument must be a column")),
815 };
816
817 let offset = if args.len() > 1 {
819 match self.evaluate(&args[1], row_index)? {
820 DataValue::Integer(i) => i as i32,
821 _ => return Err(anyhow!("LAG offset must be an integer")),
822 }
823 } else {
824 1
825 };
826
827 Ok(context
829 .get_offset_value(row_index, -offset, &column.name)
830 .unwrap_or(DataValue::Null))
831 }
832 "LEAD" => {
833 if args.is_empty() {
835 return Err(anyhow!("LEAD requires at least 1 argument"));
836 }
837
838 let column = match &args[0] {
840 SqlExpression::Column(col) => col.clone(),
841 _ => return Err(anyhow!("LEAD first argument must be a column")),
842 };
843
844 let offset = if args.len() > 1 {
846 match self.evaluate(&args[1], row_index)? {
847 DataValue::Integer(i) => i as i32,
848 _ => return Err(anyhow!("LEAD offset must be an integer")),
849 }
850 } else {
851 1
852 };
853
854 Ok(context
856 .get_offset_value(row_index, offset, &column.name)
857 .unwrap_or(DataValue::Null))
858 }
859 "ROW_NUMBER" => {
860 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
862 }
863 "FIRST_VALUE" => {
864 if args.is_empty() {
866 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
867 }
868
869 let column = match &args[0] {
870 SqlExpression::Column(col) => col.clone(),
871 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
872 };
873
874 if context.has_frame() {
876 Ok(context
877 .get_frame_first_value(row_index, &column.name)
878 .unwrap_or(DataValue::Null))
879 } else {
880 Ok(context
881 .get_first_value(row_index, &column.name)
882 .unwrap_or(DataValue::Null))
883 }
884 }
885 "LAST_VALUE" => {
886 if args.is_empty() {
888 return Err(anyhow!("LAST_VALUE requires 1 argument"));
889 }
890
891 let column = match &args[0] {
892 SqlExpression::Column(col) => col.clone(),
893 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
894 };
895
896 if context.has_frame() {
898 Ok(context
899 .get_frame_last_value(row_index, &column.name)
900 .unwrap_or(DataValue::Null))
901 } else {
902 Ok(context
903 .get_last_value(row_index, &column.name)
904 .unwrap_or(DataValue::Null))
905 }
906 }
907 "SUM" => {
908 if args.is_empty() {
910 return Err(anyhow!("SUM requires 1 argument"));
911 }
912
913 let column = match &args[0] {
914 SqlExpression::Column(col) => col.clone(),
915 _ => return Err(anyhow!("SUM argument must be a column")),
916 };
917
918 if context.has_frame() {
920 Ok(context
921 .get_frame_sum(row_index, &column.name)
922 .unwrap_or(DataValue::Null))
923 } else {
924 Ok(context
925 .get_partition_sum(row_index, &column.name)
926 .unwrap_or(DataValue::Null))
927 }
928 }
929 "AVG" => {
930 if args.is_empty() {
932 return Err(anyhow!("AVG requires 1 argument"));
933 }
934
935 let column = match &args[0] {
936 SqlExpression::Column(col) => col.clone(),
937 _ => return Err(anyhow!("AVG argument must be a column")),
938 };
939
940 Ok(context
941 .get_frame_avg(row_index, &column.name)
942 .unwrap_or(DataValue::Null))
943 }
944 "STDDEV" | "STDEV" => {
945 if args.is_empty() {
947 return Err(anyhow!("STDDEV requires 1 argument"));
948 }
949
950 let column = match &args[0] {
951 SqlExpression::Column(col) => col.clone(),
952 _ => return Err(anyhow!("STDDEV argument must be a column")),
953 };
954
955 Ok(context
956 .get_frame_stddev(row_index, &column.name)
957 .unwrap_or(DataValue::Null))
958 }
959 "VARIANCE" | "VAR" => {
960 if args.is_empty() {
962 return Err(anyhow!("VARIANCE requires 1 argument"));
963 }
964
965 let column = match &args[0] {
966 SqlExpression::Column(col) => col.clone(),
967 _ => return Err(anyhow!("VARIANCE argument must be a column")),
968 };
969
970 Ok(context
971 .get_frame_variance(row_index, &column.name)
972 .unwrap_or(DataValue::Null))
973 }
974 "MIN" => {
975 if args.is_empty() {
977 return Err(anyhow!("MIN requires 1 argument"));
978 }
979
980 let column = match &args[0] {
981 SqlExpression::Column(col) => col.clone(),
982 _ => return Err(anyhow!("MIN argument must be a column")),
983 };
984
985 let frame_rows = context.get_frame_rows(row_index);
986 if frame_rows.is_empty() {
987 return Ok(DataValue::Null);
988 }
989
990 let source_table = context.source();
991 let col_idx = source_table
992 .get_column_index(&column.name)
993 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
994
995 let mut min_value: Option<DataValue> = None;
996 for &row_idx in &frame_rows {
997 if let Some(value) = source_table.get_value(row_idx, col_idx) {
998 if !matches!(value, DataValue::Null) {
999 match &min_value {
1000 None => min_value = Some(value.clone()),
1001 Some(current_min) => {
1002 if value < current_min {
1003 min_value = Some(value.clone());
1004 }
1005 }
1006 }
1007 }
1008 }
1009 }
1010
1011 Ok(min_value.unwrap_or(DataValue::Null))
1012 }
1013 "MAX" => {
1014 if args.is_empty() {
1016 return Err(anyhow!("MAX requires 1 argument"));
1017 }
1018
1019 let column = match &args[0] {
1020 SqlExpression::Column(col) => col.clone(),
1021 _ => return Err(anyhow!("MAX argument must be a column")),
1022 };
1023
1024 let frame_rows = context.get_frame_rows(row_index);
1025 if frame_rows.is_empty() {
1026 return Ok(DataValue::Null);
1027 }
1028
1029 let source_table = context.source();
1030 let col_idx = source_table
1031 .get_column_index(&column.name)
1032 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1033
1034 let mut max_value: Option<DataValue> = None;
1035 for &row_idx in &frame_rows {
1036 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1037 if !matches!(value, DataValue::Null) {
1038 match &max_value {
1039 None => max_value = Some(value.clone()),
1040 Some(current_max) => {
1041 if value > current_max {
1042 max_value = Some(value.clone());
1043 }
1044 }
1045 }
1046 }
1047 }
1048 }
1049
1050 Ok(max_value.unwrap_or(DataValue::Null))
1051 }
1052 "COUNT" => {
1053 if args.is_empty() {
1057 if context.has_frame() {
1059 Ok(context
1060 .get_frame_count(row_index, None)
1061 .unwrap_or(DataValue::Null))
1062 } else {
1063 Ok(context
1064 .get_partition_count(row_index, None)
1065 .unwrap_or(DataValue::Null))
1066 }
1067 } else {
1068 let column = match &args[0] {
1070 SqlExpression::Column(col) => {
1071 if col.name == "*" {
1072 if context.has_frame() {
1074 return Ok(context
1075 .get_frame_count(row_index, None)
1076 .unwrap_or(DataValue::Null));
1077 } else {
1078 return Ok(context
1079 .get_partition_count(row_index, None)
1080 .unwrap_or(DataValue::Null));
1081 }
1082 }
1083 col.clone()
1084 }
1085 SqlExpression::StringLiteral(s) if s == "*" => {
1086 if context.has_frame() {
1088 return Ok(context
1089 .get_frame_count(row_index, None)
1090 .unwrap_or(DataValue::Null));
1091 } else {
1092 return Ok(context
1093 .get_partition_count(row_index, None)
1094 .unwrap_or(DataValue::Null));
1095 }
1096 }
1097 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1098 };
1099
1100 if context.has_frame() {
1102 Ok(context
1103 .get_frame_count(row_index, Some(&column.name))
1104 .unwrap_or(DataValue::Null))
1105 } else {
1106 Ok(context
1107 .get_partition_count(row_index, Some(&column.name))
1108 .unwrap_or(DataValue::Null))
1109 }
1110 }
1111 }
1112 _ => Err(anyhow!("Unknown window function: {}", name)),
1113 }
1114 }
1115
1116 fn evaluate_method_call(
1118 &mut self,
1119 object: &str,
1120 method: &str,
1121 args: &[SqlExpression],
1122 row_index: usize,
1123 ) -> Result<DataValue> {
1124 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1126 let suggestion = self.find_similar_column(object);
1127 match suggestion {
1128 Some(similar) => {
1129 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1130 }
1131 None => anyhow!("Column '{}' not found", object),
1132 }
1133 })?;
1134
1135 let cell_value = self.table.get_value(row_index, col_index).cloned();
1136
1137 self.evaluate_method_on_value(
1138 &cell_value.unwrap_or(DataValue::Null),
1139 method,
1140 args,
1141 row_index,
1142 )
1143 }
1144
1145 fn evaluate_method_on_value(
1147 &mut self,
1148 value: &DataValue,
1149 method: &str,
1150 args: &[SqlExpression],
1151 row_index: usize,
1152 ) -> Result<DataValue> {
1153 let function_name = match method.to_lowercase().as_str() {
1158 "trim" => "TRIM",
1159 "trimstart" | "trimbegin" => "TRIMSTART",
1160 "trimend" => "TRIMEND",
1161 "length" | "len" => "LENGTH",
1162 "contains" => "CONTAINS",
1163 "startswith" => "STARTSWITH",
1164 "endswith" => "ENDSWITH",
1165 "indexof" => "INDEXOF",
1166 _ => method, };
1168
1169 if self.function_registry.get(function_name).is_some() {
1171 debug!(
1172 "Proxying method '{}' through function registry as '{}'",
1173 method, function_name
1174 );
1175
1176 let mut func_args = vec![value.clone()];
1178
1179 for arg in args {
1181 func_args.push(self.evaluate(arg, row_index)?);
1182 }
1183
1184 let func = self.function_registry.get(function_name).unwrap();
1186 return func.evaluate(&func_args);
1187 }
1188
1189 Err(anyhow!(
1192 "Method '{}' not found. It should be registered in the function registry.",
1193 method
1194 ))
1195 }
1196
1197 fn evaluate_case_expression(
1199 &mut self,
1200 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1201 else_branch: &Option<Box<SqlExpression>>,
1202 row_index: usize,
1203 ) -> Result<DataValue> {
1204 debug!(
1205 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1206 row_index
1207 );
1208
1209 for branch in when_branches {
1211 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1213
1214 if condition_result {
1215 debug!("CASE: WHEN condition matched, evaluating result expression");
1216 return self.evaluate(&branch.result, row_index);
1217 }
1218 }
1219
1220 if let Some(else_expr) = else_branch {
1222 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1223 self.evaluate(else_expr, row_index)
1224 } else {
1225 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1226 Ok(DataValue::Null)
1227 }
1228 }
1229
1230 fn evaluate_simple_case_expression(
1232 &mut self,
1233 expr: &Box<SqlExpression>,
1234 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1235 else_branch: &Option<Box<SqlExpression>>,
1236 row_index: usize,
1237 ) -> Result<DataValue> {
1238 debug!(
1239 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1240 row_index
1241 );
1242
1243 let case_value = self.evaluate(expr, row_index)?;
1245 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1246
1247 for branch in when_branches {
1249 let when_value = self.evaluate(&branch.value, row_index)?;
1251
1252 if self.values_equal(&case_value, &when_value)? {
1254 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1255 return self.evaluate(&branch.result, row_index);
1256 }
1257 }
1258
1259 if let Some(else_expr) = else_branch {
1261 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1262 self.evaluate(else_expr, row_index)
1263 } else {
1264 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1265 Ok(DataValue::Null)
1266 }
1267 }
1268
1269 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1271 match (left, right) {
1272 (DataValue::Null, DataValue::Null) => Ok(true),
1273 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1274 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1275 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1276 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1277 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1278 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1279 (DataValue::Integer(a), DataValue::Float(b)) => {
1281 Ok((*a as f64 - b).abs() < f64::EPSILON)
1282 }
1283 (DataValue::Float(a), DataValue::Integer(b)) => {
1284 Ok((a - *b as f64).abs() < f64::EPSILON)
1285 }
1286 _ => Ok(false),
1287 }
1288 }
1289
1290 fn evaluate_condition_as_bool(
1292 &mut self,
1293 expr: &SqlExpression,
1294 row_index: usize,
1295 ) -> Result<bool> {
1296 let value = self.evaluate(expr, row_index)?;
1297
1298 match value {
1299 DataValue::Boolean(b) => Ok(b),
1300 DataValue::Integer(i) => Ok(i != 0),
1301 DataValue::Float(f) => Ok(f != 0.0),
1302 DataValue::Null => Ok(false),
1303 DataValue::String(s) => Ok(!s.is_empty()),
1304 DataValue::InternedString(s) => Ok(!s.is_empty()),
1305 _ => Ok(true), }
1307 }
1308}
1309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table with columns `a = 10`, `b = 2.5` and `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    /// Builds an unquoted column-reference expression for `name`.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// A column reference evaluates to the cell value of the requested row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.evaluate(&column("a"), 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    /// Number literals become `Integer` or `Float` depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // 3.25 is exactly representable in binary and, unlike 3.14, does not
        // trip clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("3.25".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(3.25));
    }

    /// Addition keeps integers as integers and promotes mixed operands to float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    /// Multiplying an integer by a float yields a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    /// Exact integer division stays an integer; inexact division yields a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    /// A binary `*` over two columns multiplies the row's cell values (10 * 2.5).
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "*".to_string(),
            right: Box::new(column("b")),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }
}