1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::Instant;
16use tracing::{debug, info};
17
18pub struct ArithmeticEvaluator<'a> {
21 table: &'a DataTable,
22 _date_notation: String,
23 function_registry: Arc<FunctionRegistry>,
24 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
27 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<u64, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
31
32impl<'a> ArithmeticEvaluator<'a> {
33 #[must_use]
34 pub fn new(table: &'a DataTable) -> Self {
35 Self {
36 table,
37 _date_notation: get_date_notation(),
38 function_registry: Arc::new(FunctionRegistry::new()),
39 aggregate_registry: Arc::new(AggregateRegistry::new()),
40 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
41 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
42 visible_rows: None,
43 window_contexts: HashMap::new(),
44 table_aliases: HashMap::new(),
45 }
46 }
47
48 #[must_use]
49 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
50 Self {
51 table,
52 _date_notation: date_notation,
53 function_registry: Arc::new(FunctionRegistry::new()),
54 aggregate_registry: Arc::new(AggregateRegistry::new()),
55 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
56 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
57 visible_rows: None,
58 window_contexts: HashMap::new(),
59 table_aliases: HashMap::new(),
60 }
61 }
62
63 #[must_use]
65 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
66 self.visible_rows = Some(rows);
67 self
68 }
69
70 #[must_use]
72 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
73 self.table_aliases = aliases;
74 self
75 }
76
77 #[must_use]
78 pub fn with_date_notation_and_registry(
79 table: &'a DataTable,
80 date_notation: String,
81 function_registry: Arc<FunctionRegistry>,
82 ) -> Self {
83 Self {
84 table,
85 _date_notation: date_notation,
86 function_registry,
87 aggregate_registry: Arc::new(AggregateRegistry::new()),
88 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
89 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
90 visible_rows: None,
91 window_contexts: HashMap::new(),
92 table_aliases: HashMap::new(),
93 }
94 }
95
96 fn find_similar_column(&self, name: &str) -> Option<String> {
98 let columns = self.table.column_names();
99 let mut best_match: Option<(String, usize)> = None;
100
101 for col in columns {
102 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
103 let max_distance = if name.len() > 10 { 3 } else { 2 };
106 if distance <= max_distance {
107 match &best_match {
108 None => best_match = Some((col, distance)),
109 Some((_, best_dist)) if distance < *best_dist => {
110 best_match = Some((col, distance));
111 }
112 _ => {}
113 }
114 }
115 }
116
117 best_match.map(|(name, _)| name)
118 }
119
120 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
122 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
124 }
125
126 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
128 debug!(
129 "ArithmeticEvaluator: evaluating {:?} for row {}",
130 expr, row_index
131 );
132
133 match expr {
134 SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
135 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
136 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
137 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
138 SqlExpression::Null => Ok(DataValue::Null),
139 SqlExpression::BinaryOp { left, op, right } => {
140 self.evaluate_binary_op(left, op, right, row_index)
141 }
142 SqlExpression::FunctionCall {
143 name,
144 args,
145 distinct,
146 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
147 SqlExpression::WindowFunction {
148 name,
149 args,
150 window_spec,
151 } => self.evaluate_window_function(name, args, window_spec, row_index),
152 SqlExpression::MethodCall {
153 object,
154 method,
155 args,
156 } => self.evaluate_method_call(object, method, args, row_index),
157 SqlExpression::ChainedMethodCall { base, method, args } => {
158 let base_value = self.evaluate(base, row_index)?;
160 self.evaluate_method_on_value(&base_value, method, args, row_index)
161 }
162 SqlExpression::Between { expr, lower, upper } => {
163 let val = self.evaluate(expr, row_index)?;
164 let lo = self.evaluate(lower, row_index)?;
165 let hi = self.evaluate(upper, row_index)?;
166 let ge = compare_with_op(&val, &lo, ">=", false);
167 let le = compare_with_op(&val, &hi, "<=", false);
168 Ok(DataValue::Boolean(ge && le))
169 }
170 SqlExpression::CaseExpression {
171 when_branches,
172 else_branch,
173 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
174 SqlExpression::SimpleCaseExpression {
175 expr,
176 when_branches,
177 else_branch,
178 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
179 SqlExpression::DateTimeConstructor {
180 year,
181 month,
182 day,
183 hour,
184 minute,
185 second,
186 } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
187 SqlExpression::DateTimeToday {
188 hour,
189 minute,
190 second,
191 } => self.evaluate_datetime_today(*hour, *minute, *second),
192 _ => Err(anyhow!(
193 "Unsupported expression type for arithmetic evaluation: {:?}",
194 expr
195 )),
196 }
197 }
198
199 fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
201 if let Some(table_prefix) = &column_ref.table_prefix {
202 let actual_table = self
204 .table_aliases
205 .get(table_prefix)
206 .map(|s| s.as_str())
207 .unwrap_or(table_prefix);
208
209 let qualified_name = format!("{}.{}", actual_table, column_ref.name);
211
212 if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
213 debug!(
214 "Resolved {}.{} -> '{}' at index {}",
215 table_prefix, column_ref.name, qualified_name, col_idx
216 );
217 return self
218 .table
219 .get_value(row_index, col_idx)
220 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
221 .map(|v| v.clone());
222 }
223
224 if let Some(col_idx) = self.table.get_column_index(&column_ref.name) {
226 debug!(
227 "Resolved {}.{} -> unqualified '{}' at index {}",
228 table_prefix, column_ref.name, column_ref.name, col_idx
229 );
230 return self
231 .table
232 .get_value(row_index, col_idx)
233 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
234 .map(|v| v.clone());
235 }
236
237 Err(anyhow!(
239 "Column '{}' not found. Table '{}' may not support qualified column names",
240 qualified_name,
241 actual_table
242 ))
243 } else {
244 self.evaluate_column(&column_ref.name, row_index)
246 }
247 }
248
249 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
251 let resolved_column = if column_name.contains('.') {
253 if let Some(dot_pos) = column_name.rfind('.') {
255 let _table_or_alias = &column_name[..dot_pos];
256 let col_name = &column_name[dot_pos + 1..];
257
258 debug!(
261 "Resolving qualified column: {} -> {}",
262 column_name, col_name
263 );
264 col_name.to_string()
265 } else {
266 column_name.to_string()
267 }
268 } else {
269 column_name.to_string()
270 };
271
272 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
273 idx
274 } else if resolved_column != column_name {
275 if let Some(idx) = self.table.get_column_index(column_name) {
277 idx
278 } else {
279 let suggestion = self.find_similar_column(&resolved_column);
280 return Err(match suggestion {
281 Some(similar) => anyhow!(
282 "Column '{}' not found. Did you mean '{}'?",
283 column_name,
284 similar
285 ),
286 None => anyhow!("Column '{}' not found", column_name),
287 });
288 }
289 } else {
290 let suggestion = self.find_similar_column(&resolved_column);
291 return Err(match suggestion {
292 Some(similar) => anyhow!(
293 "Column '{}' not found. Did you mean '{}'?",
294 column_name,
295 similar
296 ),
297 None => anyhow!("Column '{}' not found", column_name),
298 });
299 };
300
301 if row_index >= self.table.row_count() {
302 return Err(anyhow!("Row index {} out of bounds", row_index));
303 }
304
305 let row = self
306 .table
307 .get_row(row_index)
308 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
309
310 let value = row
311 .get(col_index)
312 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
313
314 Ok(value.clone())
315 }
316
317 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
319 if let Ok(int_val) = number_str.parse::<i64>() {
321 return Ok(DataValue::Integer(int_val));
322 }
323
324 if let Ok(float_val) = number_str.parse::<f64>() {
326 return Ok(DataValue::Float(float_val));
327 }
328
329 Err(anyhow!("Invalid number literal: {}", number_str))
330 }
331
332 fn evaluate_binary_op(
334 &mut self,
335 left: &SqlExpression,
336 op: &str,
337 right: &SqlExpression,
338 row_index: usize,
339 ) -> Result<DataValue> {
340 let left_val = self.evaluate(left, row_index)?;
341 let right_val = self.evaluate(right, row_index)?;
342
343 debug!(
344 "ArithmeticEvaluator: {} {} {}",
345 self.format_value(&left_val),
346 op,
347 self.format_value(&right_val)
348 );
349
350 match op {
351 "+" => self.add_values(&left_val, &right_val),
352 "-" => self.subtract_values(&left_val, &right_val),
353 "*" => self.multiply_values(&left_val, &right_val),
354 "/" => self.divide_values(&left_val, &right_val),
355 "%" => {
356 let args = vec![left.clone(), right.clone()];
358 self.evaluate_function("MOD", &args, row_index)
359 }
360 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
363 let result = compare_with_op(&left_val, &right_val, op, false);
364 Ok(DataValue::Boolean(result))
365 }
366 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
368 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
369 "AND" => {
371 let left_bool = self.to_bool(&left_val)?;
372 let right_bool = self.to_bool(&right_val)?;
373 Ok(DataValue::Boolean(left_bool && right_bool))
374 }
375 "OR" => {
376 let left_bool = self.to_bool(&left_val)?;
377 let right_bool = self.to_bool(&right_val)?;
378 Ok(DataValue::Boolean(left_bool || right_bool))
379 }
380 "LIKE" => {
382 let text = self.value_to_string(&left_val);
383 let pattern = self.value_to_string(&right_val);
384 let matches = self.sql_like_match(&text, &pattern);
385 Ok(DataValue::Boolean(matches))
386 }
387 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
388 }
389 }
390
391 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
393 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
395 return Ok(DataValue::Null);
396 }
397
398 match (left, right) {
399 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
400 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
401 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
402 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
403 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
404 }
405 }
406
407 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
409 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
411 return Ok(DataValue::Null);
412 }
413
414 match (left, right) {
415 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
416 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
417 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
418 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
419 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
420 }
421 }
422
423 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
425 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
427 return Ok(DataValue::Null);
428 }
429
430 match (left, right) {
431 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
432 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
433 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
434 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
435 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
436 }
437 }
438
439 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
441 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
443 return Ok(DataValue::Null);
444 }
445
446 let is_zero = match right {
448 DataValue::Integer(0) => true,
449 DataValue::Float(f) if *f == 0.0 => true, _ => false,
451 };
452
453 if is_zero {
454 return Err(anyhow!("Division by zero"));
455 }
456
457 match (left, right) {
458 (DataValue::Integer(a), DataValue::Integer(b)) => {
459 if a % b == 0 {
461 Ok(DataValue::Integer(a / b))
462 } else {
463 Ok(DataValue::Float(*a as f64 / *b as f64))
464 }
465 }
466 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
467 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
468 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
469 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
470 }
471 }
472
473 fn format_value(&self, value: &DataValue) -> String {
475 match value {
476 DataValue::Integer(i) => i.to_string(),
477 DataValue::Float(f) => f.to_string(),
478 DataValue::String(s) => format!("'{s}'"),
479 _ => format!("{value:?}"),
480 }
481 }
482
483 fn to_bool(&self, value: &DataValue) -> Result<bool> {
485 match value {
486 DataValue::Boolean(b) => Ok(*b),
487 DataValue::Integer(i) => Ok(*i != 0),
488 DataValue::Float(f) => Ok(*f != 0.0),
489 DataValue::Null => Ok(false),
490 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
491 }
492 }
493
494 fn value_to_string(&self, value: &DataValue) -> String {
496 match value {
497 DataValue::String(s) => s.clone(),
498 DataValue::InternedString(s) => s.to_string(),
499 DataValue::Integer(i) => i.to_string(),
500 DataValue::Float(f) => f.to_string(),
501 DataValue::Boolean(b) => b.to_string(),
502 DataValue::DateTime(dt) => dt.to_string(),
503 DataValue::Vector(v) => {
504 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
506 format!("[{}]", components.join(","))
507 }
508 DataValue::Null => String::new(),
509 }
510 }
511
512 fn sql_like_match(&self, text: &str, pattern: &str) -> bool {
515 let pattern_chars: Vec<char> = pattern.chars().collect();
516 let text_chars: Vec<char> = text.chars().collect();
517
518 self.like_match_recursive(&text_chars, 0, &pattern_chars, 0)
519 }
520
521 fn like_match_recursive(
523 &self,
524 text: &[char],
525 text_pos: usize,
526 pattern: &[char],
527 pattern_pos: usize,
528 ) -> bool {
529 if pattern_pos >= pattern.len() {
531 return text_pos >= text.len();
532 }
533
534 if pattern[pattern_pos] == '%' {
536 if self.like_match_recursive(text, text_pos, pattern, pattern_pos + 1) {
538 return true;
539 }
540 if text_pos < text.len() {
542 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos);
543 }
544 return false;
545 }
546
547 if text_pos >= text.len() {
549 return false;
550 }
551
552 if pattern[pattern_pos] == '_' {
554 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
555 }
556
557 if text[text_pos] == pattern[pattern_pos] {
559 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
560 }
561
562 false
563 }
564
565 fn evaluate_function_with_distinct(
567 &mut self,
568 name: &str,
569 args: &[SqlExpression],
570 distinct: bool,
571 row_index: usize,
572 ) -> Result<DataValue> {
573 if distinct {
575 let name_upper = name.to_uppercase();
576
577 if self.aggregate_registry.is_aggregate(&name_upper)
579 || self.new_aggregate_registry.contains(&name_upper)
580 {
581 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
582 } else {
583 return Err(anyhow!(
584 "DISTINCT can only be used with aggregate functions"
585 ));
586 }
587 }
588
589 self.evaluate_function(name, args, row_index)
591 }
592
593 fn evaluate_aggregate_with_distinct(
594 &mut self,
595 name: &str,
596 args: &[SqlExpression],
597 _row_index: usize,
598 ) -> Result<DataValue> {
599 let name_upper = name.to_uppercase();
600
601 if self.new_aggregate_registry.get(&name_upper).is_some() {
603 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
604 visible.clone()
605 } else {
606 (0..self.table.rows.len()).collect()
607 };
608
609 let mut vals = Vec::new();
611 for &row_idx in &rows_to_process {
612 if !args.is_empty() {
613 let value = self.evaluate(&args[0], row_idx)?;
614 vals.push(value);
615 }
616 }
617
618 let mut seen = HashSet::new();
620 let unique_values: Vec<_> = vals
621 .into_iter()
622 .filter(|v| {
623 let key = format!("{:?}", v);
624 seen.insert(key)
625 })
626 .collect();
627
628 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
630 let mut state = agg_func.create_state();
631
632 for value in &unique_values {
634 state.accumulate(value)?;
635 }
636
637 return Ok(state.finalize());
638 }
639
640 if self.aggregate_registry.get(&name_upper).is_some() {
642 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
644 visible.clone()
645 } else {
646 (0..self.table.rows.len()).collect()
647 };
648
649 if name_upper == "STRING_AGG" && args.len() >= 2 {
651 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
653 if args.len() >= 2 {
655 let separator = self.evaluate(&args[1], 0)?; match separator {
657 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
658 DataValue::InternedString(s) => {
659 crate::sql::aggregates::StringAggState::new(&s)
660 }
661 _ => crate::sql::aggregates::StringAggState::new(","), }
663 } else {
664 crate::sql::aggregates::StringAggState::new(",")
665 },
666 );
667
668 let mut seen_values = HashSet::new();
671
672 for &row_idx in &rows_to_process {
673 let value = self.evaluate(&args[0], row_idx)?;
674
675 if !seen_values.insert(value.clone()) {
677 continue; }
679
680 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
682 agg_func.accumulate(&mut state, &value)?;
683 }
684
685 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
687 return Ok(agg_func.finalize(state));
688 }
689
690 let mut vals = Vec::new();
693 for &row_idx in &rows_to_process {
694 if !args.is_empty() {
695 let value = self.evaluate(&args[0], row_idx)?;
696 vals.push(value);
697 }
698 }
699
700 let mut seen = HashSet::new();
702 let mut unique_values = Vec::new();
703 for value in vals {
704 if seen.insert(value.clone()) {
705 unique_values.push(value);
706 }
707 }
708
709 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
711 let mut state = agg_func.init();
712
713 for value in &unique_values {
715 agg_func.accumulate(&mut state, value)?;
716 }
717
718 return Ok(agg_func.finalize(state));
719 }
720
721 Err(anyhow!("Unknown aggregate function: {}", name))
722 }
723
724 fn evaluate_function(
725 &mut self,
726 name: &str,
727 args: &[SqlExpression],
728 row_index: usize,
729 ) -> Result<DataValue> {
730 let name_upper = name.to_uppercase();
732
733 if self.new_aggregate_registry.get(&name_upper).is_some() {
735 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
737 visible.clone()
738 } else {
739 (0..self.table.rows.len()).collect()
740 };
741
742 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
744 let mut state = agg_func.create_state();
745
746 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
748 if args.is_empty()
749 || (args.len() == 1
750 && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
751 || (args.len() == 1
752 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
753 {
754 for _ in &rows_to_process {
756 state.accumulate(&DataValue::Integer(1))?;
757 }
758 } else {
759 for &row_idx in &rows_to_process {
761 let value = self.evaluate(&args[0], row_idx)?;
762 state.accumulate(&value)?;
763 }
764 }
765 } else {
766 if !args.is_empty() {
768 for &row_idx in &rows_to_process {
769 let value = self.evaluate(&args[0], row_idx)?;
770 state.accumulate(&value)?;
771 }
772 }
773 }
774
775 return Ok(state.finalize());
776 }
777
778 if self.aggregate_registry.get(&name_upper).is_some() {
780 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
782 visible.clone()
783 } else {
784 (0..self.table.rows.len()).collect()
785 };
786
787 if name_upper == "STRING_AGG" && args.len() >= 2 {
789 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
791 if args.len() >= 2 {
793 let separator = self.evaluate(&args[1], 0)?; match separator {
795 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
796 DataValue::InternedString(s) => {
797 crate::sql::aggregates::StringAggState::new(&s)
798 }
799 _ => crate::sql::aggregates::StringAggState::new(","), }
801 } else {
802 crate::sql::aggregates::StringAggState::new(",")
803 },
804 );
805
806 for &row_idx in &rows_to_process {
808 let value = self.evaluate(&args[0], row_idx)?;
809 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
811 agg_func.accumulate(&mut state, &value)?;
812 }
813
814 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
816 return Ok(agg_func.finalize(state));
817 }
818
819 let values = if !args.is_empty()
821 && !(args.len() == 1
822 && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
823 {
824 let mut vals = Vec::new();
826 for &row_idx in &rows_to_process {
827 let value = self.evaluate(&args[0], row_idx)?;
828 vals.push(value);
829 }
830 Some(vals)
831 } else {
832 None
833 };
834
835 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
837 let mut state = agg_func.init();
838
839 if let Some(values) = values {
840 for value in &values {
842 agg_func.accumulate(&mut state, value)?;
843 }
844 } else {
845 for _ in &rows_to_process {
847 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
848 }
849 }
850
851 return Ok(agg_func.finalize(state));
852 }
853
854 if self.function_registry.get(name).is_some() {
856 let mut evaluated_args = Vec::new();
858 for arg in args {
859 evaluated_args.push(self.evaluate(arg, row_index)?);
860 }
861
862 let func = self.function_registry.get(name).unwrap();
864 return func.evaluate(&evaluated_args);
865 }
866
867 Err(anyhow!("Unknown function: {}", name))
869 }
870
871 pub fn get_or_create_window_context(
874 &mut self,
875 spec: &WindowSpec,
876 ) -> Result<Arc<WindowContext>> {
877 let overall_start = Instant::now();
878
879 let key = spec.compute_hash();
881
882 if let Some(context) = self.window_contexts.get(&key) {
883 info!(
884 "WindowContext cache hit for spec (lookup: {:.2}μs)",
885 overall_start.elapsed().as_micros()
886 );
887 return Ok(Arc::clone(context));
888 }
889
890 info!("WindowContext cache miss - creating new context");
891 let dataview_start = Instant::now();
892
893 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
895 let view = DataView::new(Arc::new(self.table.clone()));
897 view
900 } else {
901 DataView::new(Arc::new(self.table.clone()))
902 };
903
904 info!(
905 "DataView creation took {:.2}μs",
906 dataview_start.elapsed().as_micros()
907 );
908 let context_start = Instant::now();
909
910 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
912
913 info!(
914 "WindowContext::new_with_spec took {:.2}ms (rows: {})",
915 context_start.elapsed().as_secs_f64() * 1000.0,
916 self.table.row_count()
917 );
918
919 let context = Arc::new(context);
920 self.window_contexts.insert(key, Arc::clone(&context));
921
922 info!(
923 "Total WindowContext creation (cache miss) took {:.2}ms",
924 overall_start.elapsed().as_secs_f64() * 1000.0
925 );
926
927 Ok(context)
928 }
929
930 fn evaluate_window_function(
932 &mut self,
933 name: &str,
934 args: &[SqlExpression],
935 spec: &WindowSpec,
936 row_index: usize,
937 ) -> Result<DataValue> {
938 let func_start = Instant::now();
939 let name_upper = name.to_uppercase();
940
941 debug!("Looking for window function {} in registry", name_upper);
943 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
944 debug!("Found window function {} in registry", name_upper);
945
946 let window_fn = window_fn_arc.as_ref();
948
949 window_fn.validate_args(args)?;
951
952 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
954
955 let context = self.get_or_create_window_context(&transformed_spec)?;
957
958 struct EvaluatorAdapter<'a, 'b> {
960 evaluator: &'a mut ArithmeticEvaluator<'b>,
961 row_index: usize,
962 }
963
964 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
965 fn evaluate(
966 &mut self,
967 expr: &SqlExpression,
968 row_index: usize,
969 ) -> Result<DataValue> {
970 self.evaluator.evaluate(expr, row_index)
971 }
972 }
973
974 let mut adapter = EvaluatorAdapter {
975 evaluator: self,
976 row_index,
977 };
978
979 let compute_start = Instant::now();
980 let result = window_fn.compute(&context, row_index, args, &mut adapter);
982
983 info!(
984 "{} (registry) evaluation: total={:.2}μs, compute={:.2}μs",
985 name_upper,
986 func_start.elapsed().as_micros(),
987 compute_start.elapsed().as_micros()
988 );
989
990 return result;
991 }
992
993 let context_start = Instant::now();
995 let context = self.get_or_create_window_context(spec)?;
996 let context_time = context_start.elapsed();
997
998 let eval_start = Instant::now();
999
1000 let result = match name_upper.as_str() {
1001 "LAG" => {
1002 if args.is_empty() {
1004 return Err(anyhow!("LAG requires at least 1 argument"));
1005 }
1006
1007 let column = match &args[0] {
1009 SqlExpression::Column(col) => col.clone(),
1010 _ => return Err(anyhow!("LAG first argument must be a column")),
1011 };
1012
1013 let offset = if args.len() > 1 {
1015 match self.evaluate(&args[1], row_index)? {
1016 DataValue::Integer(i) => i as i32,
1017 _ => return Err(anyhow!("LAG offset must be an integer")),
1018 }
1019 } else {
1020 1
1021 };
1022
1023 let offset_start = Instant::now();
1024 let value = context
1026 .get_offset_value(row_index, -offset, &column.name)
1027 .unwrap_or(DataValue::Null);
1028
1029 debug!(
1030 "LAG offset access took {:.2}μs (offset={})",
1031 offset_start.elapsed().as_micros(),
1032 offset
1033 );
1034
1035 Ok(value)
1036 }
1037 "LEAD" => {
1038 if args.is_empty() {
1040 return Err(anyhow!("LEAD requires at least 1 argument"));
1041 }
1042
1043 let column = match &args[0] {
1045 SqlExpression::Column(col) => col.clone(),
1046 _ => return Err(anyhow!("LEAD first argument must be a column")),
1047 };
1048
1049 let offset = if args.len() > 1 {
1051 match self.evaluate(&args[1], row_index)? {
1052 DataValue::Integer(i) => i as i32,
1053 _ => return Err(anyhow!("LEAD offset must be an integer")),
1054 }
1055 } else {
1056 1
1057 };
1058
1059 let offset_start = Instant::now();
1060 let value = context
1062 .get_offset_value(row_index, offset, &column.name)
1063 .unwrap_or(DataValue::Null);
1064
1065 debug!(
1066 "LEAD offset access took {:.2}μs (offset={})",
1067 offset_start.elapsed().as_micros(),
1068 offset
1069 );
1070
1071 Ok(value)
1072 }
1073 "ROW_NUMBER" => {
1074 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
1076 }
1077 "RANK" => {
1078 Ok(DataValue::Integer(context.get_rank(row_index)))
1080 }
1081 "DENSE_RANK" => {
1082 Ok(DataValue::Integer(context.get_dense_rank(row_index)))
1084 }
1085 "FIRST_VALUE" => {
1086 if args.is_empty() {
1088 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
1089 }
1090
1091 let column = match &args[0] {
1092 SqlExpression::Column(col) => col.clone(),
1093 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
1094 };
1095
1096 if context.has_frame() {
1098 Ok(context
1099 .get_frame_first_value(row_index, &column.name)
1100 .unwrap_or(DataValue::Null))
1101 } else {
1102 Ok(context
1103 .get_first_value(row_index, &column.name)
1104 .unwrap_or(DataValue::Null))
1105 }
1106 }
1107 "LAST_VALUE" => {
1108 if args.is_empty() {
1110 return Err(anyhow!("LAST_VALUE requires 1 argument"));
1111 }
1112
1113 let column = match &args[0] {
1114 SqlExpression::Column(col) => col.clone(),
1115 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
1116 };
1117
1118 if context.has_frame() {
1120 Ok(context
1121 .get_frame_last_value(row_index, &column.name)
1122 .unwrap_or(DataValue::Null))
1123 } else {
1124 Ok(context
1125 .get_last_value(row_index, &column.name)
1126 .unwrap_or(DataValue::Null))
1127 }
1128 }
1129 "SUM" => {
1130 if args.is_empty() {
1132 return Err(anyhow!("SUM requires 1 argument"));
1133 }
1134
1135 let column = match &args[0] {
1136 SqlExpression::Column(col) => col.clone(),
1137 _ => return Err(anyhow!("SUM argument must be a column")),
1138 };
1139
1140 if context.has_frame() {
1142 Ok(context
1143 .get_frame_sum(row_index, &column.name)
1144 .unwrap_or(DataValue::Null))
1145 } else {
1146 Ok(context
1147 .get_partition_sum(row_index, &column.name)
1148 .unwrap_or(DataValue::Null))
1149 }
1150 }
1151 "AVG" => {
1152 if args.is_empty() {
1154 return Err(anyhow!("AVG requires 1 argument"));
1155 }
1156
1157 let column = match &args[0] {
1158 SqlExpression::Column(col) => col.clone(),
1159 _ => return Err(anyhow!("AVG argument must be a column")),
1160 };
1161
1162 if context.has_frame() {
1164 Ok(context
1165 .get_frame_avg(row_index, &column.name)
1166 .unwrap_or(DataValue::Null))
1167 } else {
1168 Ok(context
1169 .get_partition_avg(row_index, &column.name)
1170 .unwrap_or(DataValue::Null))
1171 }
1172 }
1173 "STDDEV" | "STDEV" => {
1174 if args.is_empty() {
1176 return Err(anyhow!("STDDEV requires 1 argument"));
1177 }
1178
1179 let column = match &args[0] {
1180 SqlExpression::Column(col) => col.clone(),
1181 _ => return Err(anyhow!("STDDEV argument must be a column")),
1182 };
1183
1184 Ok(context
1185 .get_frame_stddev(row_index, &column.name)
1186 .unwrap_or(DataValue::Null))
1187 }
1188 "VARIANCE" | "VAR" => {
1189 if args.is_empty() {
1191 return Err(anyhow!("VARIANCE requires 1 argument"));
1192 }
1193
1194 let column = match &args[0] {
1195 SqlExpression::Column(col) => col.clone(),
1196 _ => return Err(anyhow!("VARIANCE argument must be a column")),
1197 };
1198
1199 Ok(context
1200 .get_frame_variance(row_index, &column.name)
1201 .unwrap_or(DataValue::Null))
1202 }
1203 "MIN" => {
1204 if args.is_empty() {
1206 return Err(anyhow!("MIN requires 1 argument"));
1207 }
1208
1209 let column = match &args[0] {
1210 SqlExpression::Column(col) => col.clone(),
1211 _ => return Err(anyhow!("MIN argument must be a column")),
1212 };
1213
1214 let frame_rows = context.get_frame_rows(row_index);
1215 if frame_rows.is_empty() {
1216 return Ok(DataValue::Null);
1217 }
1218
1219 let source_table = context.source();
1220 let col_idx = source_table
1221 .get_column_index(&column.name)
1222 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1223
1224 let mut min_value: Option<DataValue> = None;
1225 for &row_idx in &frame_rows {
1226 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1227 if !matches!(value, DataValue::Null) {
1228 match &min_value {
1229 None => min_value = Some(value.clone()),
1230 Some(current_min) => {
1231 if value < current_min {
1232 min_value = Some(value.clone());
1233 }
1234 }
1235 }
1236 }
1237 }
1238 }
1239
1240 Ok(min_value.unwrap_or(DataValue::Null))
1241 }
1242 "MAX" => {
1243 if args.is_empty() {
1245 return Err(anyhow!("MAX requires 1 argument"));
1246 }
1247
1248 let column = match &args[0] {
1249 SqlExpression::Column(col) => col.clone(),
1250 _ => return Err(anyhow!("MAX argument must be a column")),
1251 };
1252
1253 let frame_rows = context.get_frame_rows(row_index);
1254 if frame_rows.is_empty() {
1255 return Ok(DataValue::Null);
1256 }
1257
1258 let source_table = context.source();
1259 let col_idx = source_table
1260 .get_column_index(&column.name)
1261 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1262
1263 let mut max_value: Option<DataValue> = None;
1264 for &row_idx in &frame_rows {
1265 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1266 if !matches!(value, DataValue::Null) {
1267 match &max_value {
1268 None => max_value = Some(value.clone()),
1269 Some(current_max) => {
1270 if value > current_max {
1271 max_value = Some(value.clone());
1272 }
1273 }
1274 }
1275 }
1276 }
1277 }
1278
1279 Ok(max_value.unwrap_or(DataValue::Null))
1280 }
1281 "COUNT" => {
1282 if args.is_empty() {
1286 if context.has_frame() {
1288 Ok(context
1289 .get_frame_count(row_index, None)
1290 .unwrap_or(DataValue::Null))
1291 } else {
1292 Ok(context
1293 .get_partition_count(row_index, None)
1294 .unwrap_or(DataValue::Null))
1295 }
1296 } else {
1297 let column = match &args[0] {
1299 SqlExpression::Column(col) => {
1300 if col.name == "*" {
1301 if context.has_frame() {
1303 return Ok(context
1304 .get_frame_count(row_index, None)
1305 .unwrap_or(DataValue::Null));
1306 } else {
1307 return Ok(context
1308 .get_partition_count(row_index, None)
1309 .unwrap_or(DataValue::Null));
1310 }
1311 }
1312 col.clone()
1313 }
1314 SqlExpression::StringLiteral(s) if s == "*" => {
1315 if context.has_frame() {
1317 return Ok(context
1318 .get_frame_count(row_index, None)
1319 .unwrap_or(DataValue::Null));
1320 } else {
1321 return Ok(context
1322 .get_partition_count(row_index, None)
1323 .unwrap_or(DataValue::Null));
1324 }
1325 }
1326 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1327 };
1328
1329 if context.has_frame() {
1331 Ok(context
1332 .get_frame_count(row_index, Some(&column.name))
1333 .unwrap_or(DataValue::Null))
1334 } else {
1335 Ok(context
1336 .get_partition_count(row_index, Some(&column.name))
1337 .unwrap_or(DataValue::Null))
1338 }
1339 }
1340 }
1341 _ => Err(anyhow!("Unknown window function: {}", name)),
1342 };
1343
1344 let eval_time = eval_start.elapsed();
1345
1346 info!(
1347 "{} (built-in) evaluation: total={:.2}μs, context={:.2}μs, eval={:.2}μs",
1348 name_upper,
1349 func_start.elapsed().as_micros(),
1350 context_time.as_micros(),
1351 eval_time.as_micros()
1352 );
1353
1354 result
1355 }
1356
1357 fn evaluate_method_call(
1359 &mut self,
1360 object: &str,
1361 method: &str,
1362 args: &[SqlExpression],
1363 row_index: usize,
1364 ) -> Result<DataValue> {
1365 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1367 let suggestion = self.find_similar_column(object);
1368 match suggestion {
1369 Some(similar) => {
1370 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1371 }
1372 None => anyhow!("Column '{}' not found", object),
1373 }
1374 })?;
1375
1376 let cell_value = self.table.get_value(row_index, col_index).cloned();
1377
1378 self.evaluate_method_on_value(
1379 &cell_value.unwrap_or(DataValue::Null),
1380 method,
1381 args,
1382 row_index,
1383 )
1384 }
1385
1386 fn evaluate_method_on_value(
1388 &mut self,
1389 value: &DataValue,
1390 method: &str,
1391 args: &[SqlExpression],
1392 row_index: usize,
1393 ) -> Result<DataValue> {
1394 let function_name = match method.to_lowercase().as_str() {
1399 "trim" => "TRIM",
1400 "trimstart" | "trimbegin" => "TRIMSTART",
1401 "trimend" => "TRIMEND",
1402 "length" | "len" => "LENGTH",
1403 "contains" => "CONTAINS",
1404 "startswith" => "STARTSWITH",
1405 "endswith" => "ENDSWITH",
1406 "indexof" => "INDEXOF",
1407 _ => method, };
1409
1410 if self.function_registry.get(function_name).is_some() {
1412 debug!(
1413 "Proxying method '{}' through function registry as '{}'",
1414 method, function_name
1415 );
1416
1417 let mut func_args = vec![value.clone()];
1419
1420 for arg in args {
1422 func_args.push(self.evaluate(arg, row_index)?);
1423 }
1424
1425 let func = self.function_registry.get(function_name).unwrap();
1427 return func.evaluate(&func_args);
1428 }
1429
1430 Err(anyhow!(
1433 "Method '{}' not found. It should be registered in the function registry.",
1434 method
1435 ))
1436 }
1437
1438 fn evaluate_case_expression(
1440 &mut self,
1441 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1442 else_branch: &Option<Box<SqlExpression>>,
1443 row_index: usize,
1444 ) -> Result<DataValue> {
1445 debug!(
1446 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1447 row_index
1448 );
1449
1450 for branch in when_branches {
1452 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1454
1455 if condition_result {
1456 debug!("CASE: WHEN condition matched, evaluating result expression");
1457 return self.evaluate(&branch.result, row_index);
1458 }
1459 }
1460
1461 if let Some(else_expr) = else_branch {
1463 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1464 self.evaluate(else_expr, row_index)
1465 } else {
1466 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1467 Ok(DataValue::Null)
1468 }
1469 }
1470
1471 fn evaluate_simple_case_expression(
1473 &mut self,
1474 expr: &Box<SqlExpression>,
1475 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1476 else_branch: &Option<Box<SqlExpression>>,
1477 row_index: usize,
1478 ) -> Result<DataValue> {
1479 debug!(
1480 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1481 row_index
1482 );
1483
1484 let case_value = self.evaluate(expr, row_index)?;
1486 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1487
1488 for branch in when_branches {
1490 let when_value = self.evaluate(&branch.value, row_index)?;
1492
1493 if self.values_equal(&case_value, &when_value)? {
1495 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1496 return self.evaluate(&branch.result, row_index);
1497 }
1498 }
1499
1500 if let Some(else_expr) = else_branch {
1502 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1503 self.evaluate(else_expr, row_index)
1504 } else {
1505 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1506 Ok(DataValue::Null)
1507 }
1508 }
1509
1510 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1512 match (left, right) {
1513 (DataValue::Null, DataValue::Null) => Ok(true),
1514 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1515 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1516 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1517 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1518 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1519 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1520 (DataValue::Integer(a), DataValue::Float(b)) => {
1522 Ok((*a as f64 - b).abs() < f64::EPSILON)
1523 }
1524 (DataValue::Float(a), DataValue::Integer(b)) => {
1525 Ok((a - *b as f64).abs() < f64::EPSILON)
1526 }
1527 _ => Ok(false),
1528 }
1529 }
1530
1531 fn evaluate_condition_as_bool(
1533 &mut self,
1534 expr: &SqlExpression,
1535 row_index: usize,
1536 ) -> Result<bool> {
1537 let value = self.evaluate(expr, row_index)?;
1538
1539 match value {
1540 DataValue::Boolean(b) => Ok(b),
1541 DataValue::Integer(i) => Ok(i != 0),
1542 DataValue::Float(f) => Ok(f != 0.0),
1543 DataValue::Null => Ok(false),
1544 DataValue::String(s) => Ok(!s.is_empty()),
1545 DataValue::InternedString(s) => Ok(!s.is_empty()),
1546 _ => Ok(true), }
1548 }
1549
1550 fn evaluate_datetime_constructor(
1552 &self,
1553 year: i32,
1554 month: u32,
1555 day: u32,
1556 hour: Option<u32>,
1557 minute: Option<u32>,
1558 second: Option<u32>,
1559 ) -> Result<DataValue> {
1560 use chrono::{NaiveDate, TimeZone, Utc};
1561
1562 let date = NaiveDate::from_ymd_opt(year, month, day)
1564 .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1565
1566 let hour = hour.unwrap_or(0);
1568 let minute = minute.unwrap_or(0);
1569 let second = second.unwrap_or(0);
1570
1571 let naive_datetime = date
1572 .and_hms_opt(hour, minute, second)
1573 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1574
1575 let datetime = Utc.from_utc_datetime(&naive_datetime);
1577
1578 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1580 Ok(DataValue::String(datetime_str))
1581 }
1582
1583 fn evaluate_datetime_today(
1585 &self,
1586 hour: Option<u32>,
1587 minute: Option<u32>,
1588 second: Option<u32>,
1589 ) -> Result<DataValue> {
1590 use chrono::{TimeZone, Utc};
1591
1592 let today = Utc::now().date_naive();
1594
1595 let hour = hour.unwrap_or(0);
1597 let minute = minute.unwrap_or(0);
1598 let second = second.unwrap_or(0);
1599
1600 let naive_datetime = today
1601 .and_hms_opt(hour, minute, second)
1602 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1603
1604 let datetime = Utc.from_utc_datetime(&naive_datetime);
1606
1607 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1609 Ok(DataValue::String(datetime_str))
1610 }
1611}
1612
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a single-row table named `test` with columns `a` = 10,
    /// `b` = 2.5 and `c` = 4.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_between_column_in_range() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            lower: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("20".to_string())),
        };
        assert_eq!(
            evaluator.evaluate(&expr, 0).unwrap(),
            DataValue::Boolean(true)
        );
    }

    #[test]
    fn test_evaluate_between_column_out_of_range() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            lower: Box::new(SqlExpression::NumberLiteral("11".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("20".to_string())),
        };
        assert_eq!(
            evaluator.evaluate(&expr, 0).unwrap(),
            DataValue::Boolean(false)
        );
    }

    #[test]
    fn test_evaluate_between_endpoints_inclusive() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            lower: Box::new(SqlExpression::NumberLiteral("10".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("10".to_string())),
        };
        assert_eq!(
            evaluator.evaluate(&expr, 0).unwrap(),
            DataValue::Boolean(true)
        );
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Use a literal that is exactly representable in f64 and does not
        // trip clippy's deny-by-default `approx_constant` lint like `3.14` does.
        let expr = SqlExpression::NumberLiteral("1.25".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(1.25));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal_numeric_and_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integers and floats compare across types.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        // NULL only equals NULL.
        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());
        // Numbers and strings never compare equal.
        assert!(!evaluator
            .values_equal(&DataValue::Integer(1), &DataValue::String("1".to_string()))
            .unwrap());
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let zero = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());

        let col_a = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        assert!(evaluator.evaluate_condition_as_bool(&col_a, 0).unwrap());
    }

    #[test]
    fn test_datetime_constructor() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let value = evaluator
            .evaluate_datetime_constructor(2024, 2, 29, Some(13), Some(5), None)
            .unwrap();
        assert_eq!(
            value,
            DataValue::String("2024-02-29 13:05:00.000".to_string())
        );

        // 2023 is not a leap year; hour 24 is out of range.
        assert!(evaluator
            .evaluate_datetime_constructor(2023, 2, 29, None, None, None)
            .is_err());
        assert!(evaluator
            .evaluate_datetime_constructor(2024, 1, 1, Some(24), None, None)
            .is_err());
    }

    #[test]
    fn test_method_call_unknown_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let err = evaluator
            .evaluate_method_call("zzz", "trim", &[], 0)
            .unwrap_err();
        assert!(err.to_string().contains("Column 'zzz' not found"));
    }
}