1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::Instant;
16use tracing::{debug, info};
17
18pub struct ArithmeticEvaluator<'a> {
21 table: &'a DataTable,
22 _date_notation: String,
23 function_registry: Arc<FunctionRegistry>,
24 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
27 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<u64, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
31
32impl<'a> ArithmeticEvaluator<'a> {
33 #[must_use]
34 pub fn new(table: &'a DataTable) -> Self {
35 Self {
36 table,
37 _date_notation: get_date_notation(),
38 function_registry: Arc::new(FunctionRegistry::new()),
39 aggregate_registry: Arc::new(AggregateRegistry::new()),
40 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
41 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
42 visible_rows: None,
43 window_contexts: HashMap::new(),
44 table_aliases: HashMap::new(),
45 }
46 }
47
48 #[must_use]
49 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
50 Self {
51 table,
52 _date_notation: date_notation,
53 function_registry: Arc::new(FunctionRegistry::new()),
54 aggregate_registry: Arc::new(AggregateRegistry::new()),
55 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
56 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
57 visible_rows: None,
58 window_contexts: HashMap::new(),
59 table_aliases: HashMap::new(),
60 }
61 }
62
63 #[must_use]
65 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
66 self.visible_rows = Some(rows);
67 self
68 }
69
70 #[must_use]
72 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
73 self.table_aliases = aliases;
74 self
75 }
76
77 #[must_use]
78 pub fn with_date_notation_and_registry(
79 table: &'a DataTable,
80 date_notation: String,
81 function_registry: Arc<FunctionRegistry>,
82 ) -> Self {
83 Self {
84 table,
85 _date_notation: date_notation,
86 function_registry,
87 aggregate_registry: Arc::new(AggregateRegistry::new()),
88 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
89 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
90 visible_rows: None,
91 window_contexts: HashMap::new(),
92 table_aliases: HashMap::new(),
93 }
94 }
95
96 fn find_similar_column(&self, name: &str) -> Option<String> {
98 let columns = self.table.column_names();
99 let mut best_match: Option<(String, usize)> = None;
100
101 for col in columns {
102 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
103 let max_distance = if name.len() > 10 { 3 } else { 2 };
106 if distance <= max_distance {
107 match &best_match {
108 None => best_match = Some((col, distance)),
109 Some((_, best_dist)) if distance < *best_dist => {
110 best_match = Some((col, distance));
111 }
112 _ => {}
113 }
114 }
115 }
116
117 best_match.map(|(name, _)| name)
118 }
119
120 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
122 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
124 }
125
126 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
128 debug!(
129 "ArithmeticEvaluator: evaluating {:?} for row {}",
130 expr, row_index
131 );
132
133 match expr {
134 SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
135 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
136 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
137 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
138 SqlExpression::Null => Ok(DataValue::Null),
139 SqlExpression::BinaryOp { left, op, right } => {
140 self.evaluate_binary_op(left, op, right, row_index)
141 }
142 SqlExpression::FunctionCall {
143 name,
144 args,
145 distinct,
146 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
147 SqlExpression::WindowFunction {
148 name,
149 args,
150 window_spec,
151 } => self.evaluate_window_function(name, args, window_spec, row_index),
152 SqlExpression::MethodCall {
153 object,
154 method,
155 args,
156 } => self.evaluate_method_call(object, method, args, row_index),
157 SqlExpression::ChainedMethodCall { base, method, args } => {
158 let base_value = self.evaluate(base, row_index)?;
160 self.evaluate_method_on_value(&base_value, method, args, row_index)
161 }
162 SqlExpression::CaseExpression {
163 when_branches,
164 else_branch,
165 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
166 SqlExpression::SimpleCaseExpression {
167 expr,
168 when_branches,
169 else_branch,
170 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
171 SqlExpression::DateTimeConstructor {
172 year,
173 month,
174 day,
175 hour,
176 minute,
177 second,
178 } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
179 SqlExpression::DateTimeToday {
180 hour,
181 minute,
182 second,
183 } => self.evaluate_datetime_today(*hour, *minute, *second),
184 _ => Err(anyhow!(
185 "Unsupported expression type for arithmetic evaluation: {:?}",
186 expr
187 )),
188 }
189 }
190
191 fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
193 if let Some(table_prefix) = &column_ref.table_prefix {
194 let actual_table = self
196 .table_aliases
197 .get(table_prefix)
198 .map(|s| s.as_str())
199 .unwrap_or(table_prefix);
200
201 let qualified_name = format!("{}.{}", actual_table, column_ref.name);
203
204 if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
205 debug!(
206 "Resolved {}.{} -> '{}' at index {}",
207 table_prefix, column_ref.name, qualified_name, col_idx
208 );
209 return self
210 .table
211 .get_value(row_index, col_idx)
212 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
213 .map(|v| v.clone());
214 }
215
216 if let Some(col_idx) = self.table.get_column_index(&column_ref.name) {
218 debug!(
219 "Resolved {}.{} -> unqualified '{}' at index {}",
220 table_prefix, column_ref.name, column_ref.name, col_idx
221 );
222 return self
223 .table
224 .get_value(row_index, col_idx)
225 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
226 .map(|v| v.clone());
227 }
228
229 Err(anyhow!(
231 "Column '{}' not found. Table '{}' may not support qualified column names",
232 qualified_name,
233 actual_table
234 ))
235 } else {
236 self.evaluate_column(&column_ref.name, row_index)
238 }
239 }
240
241 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
243 let resolved_column = if column_name.contains('.') {
245 if let Some(dot_pos) = column_name.rfind('.') {
247 let _table_or_alias = &column_name[..dot_pos];
248 let col_name = &column_name[dot_pos + 1..];
249
250 debug!(
253 "Resolving qualified column: {} -> {}",
254 column_name, col_name
255 );
256 col_name.to_string()
257 } else {
258 column_name.to_string()
259 }
260 } else {
261 column_name.to_string()
262 };
263
264 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
265 idx
266 } else if resolved_column != column_name {
267 if let Some(idx) = self.table.get_column_index(column_name) {
269 idx
270 } else {
271 let suggestion = self.find_similar_column(&resolved_column);
272 return Err(match suggestion {
273 Some(similar) => anyhow!(
274 "Column '{}' not found. Did you mean '{}'?",
275 column_name,
276 similar
277 ),
278 None => anyhow!("Column '{}' not found", column_name),
279 });
280 }
281 } else {
282 let suggestion = self.find_similar_column(&resolved_column);
283 return Err(match suggestion {
284 Some(similar) => anyhow!(
285 "Column '{}' not found. Did you mean '{}'?",
286 column_name,
287 similar
288 ),
289 None => anyhow!("Column '{}' not found", column_name),
290 });
291 };
292
293 if row_index >= self.table.row_count() {
294 return Err(anyhow!("Row index {} out of bounds", row_index));
295 }
296
297 let row = self
298 .table
299 .get_row(row_index)
300 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
301
302 let value = row
303 .get(col_index)
304 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
305
306 Ok(value.clone())
307 }
308
309 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
311 if let Ok(int_val) = number_str.parse::<i64>() {
313 return Ok(DataValue::Integer(int_val));
314 }
315
316 if let Ok(float_val) = number_str.parse::<f64>() {
318 return Ok(DataValue::Float(float_val));
319 }
320
321 Err(anyhow!("Invalid number literal: {}", number_str))
322 }
323
324 fn evaluate_binary_op(
326 &mut self,
327 left: &SqlExpression,
328 op: &str,
329 right: &SqlExpression,
330 row_index: usize,
331 ) -> Result<DataValue> {
332 let left_val = self.evaluate(left, row_index)?;
333 let right_val = self.evaluate(right, row_index)?;
334
335 debug!(
336 "ArithmeticEvaluator: {} {} {}",
337 self.format_value(&left_val),
338 op,
339 self.format_value(&right_val)
340 );
341
342 match op {
343 "+" => self.add_values(&left_val, &right_val),
344 "-" => self.subtract_values(&left_val, &right_val),
345 "*" => self.multiply_values(&left_val, &right_val),
346 "/" => self.divide_values(&left_val, &right_val),
347 "%" => {
348 let args = vec![left.clone(), right.clone()];
350 self.evaluate_function("MOD", &args, row_index)
351 }
352 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
355 let result = compare_with_op(&left_val, &right_val, op, false);
356 Ok(DataValue::Boolean(result))
357 }
358 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
360 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
361 "AND" => {
363 let left_bool = self.to_bool(&left_val)?;
364 let right_bool = self.to_bool(&right_val)?;
365 Ok(DataValue::Boolean(left_bool && right_bool))
366 }
367 "OR" => {
368 let left_bool = self.to_bool(&left_val)?;
369 let right_bool = self.to_bool(&right_val)?;
370 Ok(DataValue::Boolean(left_bool || right_bool))
371 }
372 "LIKE" => {
374 let text = self.value_to_string(&left_val);
375 let pattern = self.value_to_string(&right_val);
376 let matches = self.sql_like_match(&text, &pattern);
377 Ok(DataValue::Boolean(matches))
378 }
379 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
380 }
381 }
382
383 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
385 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
387 return Ok(DataValue::Null);
388 }
389
390 match (left, right) {
391 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
392 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
393 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
394 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
395 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
396 }
397 }
398
399 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
401 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
403 return Ok(DataValue::Null);
404 }
405
406 match (left, right) {
407 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
408 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
409 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
410 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
411 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
412 }
413 }
414
415 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
417 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
419 return Ok(DataValue::Null);
420 }
421
422 match (left, right) {
423 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
424 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
425 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
426 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
427 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
428 }
429 }
430
431 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
433 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
435 return Ok(DataValue::Null);
436 }
437
438 let is_zero = match right {
440 DataValue::Integer(0) => true,
441 DataValue::Float(f) if *f == 0.0 => true, _ => false,
443 };
444
445 if is_zero {
446 return Err(anyhow!("Division by zero"));
447 }
448
449 match (left, right) {
450 (DataValue::Integer(a), DataValue::Integer(b)) => {
451 if a % b == 0 {
453 Ok(DataValue::Integer(a / b))
454 } else {
455 Ok(DataValue::Float(*a as f64 / *b as f64))
456 }
457 }
458 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
459 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
460 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
461 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
462 }
463 }
464
465 fn format_value(&self, value: &DataValue) -> String {
467 match value {
468 DataValue::Integer(i) => i.to_string(),
469 DataValue::Float(f) => f.to_string(),
470 DataValue::String(s) => format!("'{s}'"),
471 _ => format!("{value:?}"),
472 }
473 }
474
475 fn to_bool(&self, value: &DataValue) -> Result<bool> {
477 match value {
478 DataValue::Boolean(b) => Ok(*b),
479 DataValue::Integer(i) => Ok(*i != 0),
480 DataValue::Float(f) => Ok(*f != 0.0),
481 DataValue::Null => Ok(false),
482 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
483 }
484 }
485
486 fn value_to_string(&self, value: &DataValue) -> String {
488 match value {
489 DataValue::String(s) => s.clone(),
490 DataValue::InternedString(s) => s.to_string(),
491 DataValue::Integer(i) => i.to_string(),
492 DataValue::Float(f) => f.to_string(),
493 DataValue::Boolean(b) => b.to_string(),
494 DataValue::DateTime(dt) => dt.to_string(),
495 DataValue::Vector(v) => {
496 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
498 format!("[{}]", components.join(","))
499 }
500 DataValue::Null => String::new(),
501 }
502 }
503
504 fn sql_like_match(&self, text: &str, pattern: &str) -> bool {
507 let pattern_chars: Vec<char> = pattern.chars().collect();
508 let text_chars: Vec<char> = text.chars().collect();
509
510 self.like_match_recursive(&text_chars, 0, &pattern_chars, 0)
511 }
512
513 fn like_match_recursive(
515 &self,
516 text: &[char],
517 text_pos: usize,
518 pattern: &[char],
519 pattern_pos: usize,
520 ) -> bool {
521 if pattern_pos >= pattern.len() {
523 return text_pos >= text.len();
524 }
525
526 if pattern[pattern_pos] == '%' {
528 if self.like_match_recursive(text, text_pos, pattern, pattern_pos + 1) {
530 return true;
531 }
532 if text_pos < text.len() {
534 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos);
535 }
536 return false;
537 }
538
539 if text_pos >= text.len() {
541 return false;
542 }
543
544 if pattern[pattern_pos] == '_' {
546 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
547 }
548
549 if text[text_pos] == pattern[pattern_pos] {
551 return self.like_match_recursive(text, text_pos + 1, pattern, pattern_pos + 1);
552 }
553
554 false
555 }
556
557 fn evaluate_function_with_distinct(
559 &mut self,
560 name: &str,
561 args: &[SqlExpression],
562 distinct: bool,
563 row_index: usize,
564 ) -> Result<DataValue> {
565 if distinct {
567 let name_upper = name.to_uppercase();
568
569 if self.aggregate_registry.is_aggregate(&name_upper)
571 || self.new_aggregate_registry.contains(&name_upper)
572 {
573 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
574 } else {
575 return Err(anyhow!(
576 "DISTINCT can only be used with aggregate functions"
577 ));
578 }
579 }
580
581 self.evaluate_function(name, args, row_index)
583 }
584
/// Evaluates `name(DISTINCT args[0] [, args[1]])` over the visible rows, or over
/// every row of the table when no visible-row filter is set.
///
/// `new_aggregate_registry` is consulted first, then `aggregate_registry`.
/// Only `args[0]` is aggregated. For `STRING_AGG` with two or more arguments,
/// `args[1]` supplies the separator. `_row_index` is ignored: the result is
/// the same for every row.
///
/// # Errors
/// Propagates errors from evaluating arguments or accumulating values, and
/// returns `Unknown aggregate function` when neither registry contains `name`.
fn evaluate_aggregate_with_distinct(
    &mut self,
    name: &str,
    args: &[SqlExpression],
    _row_index: usize,
) -> Result<DataValue> {
    let name_upper = name.to_uppercase();

    if self.new_aggregate_registry.get(&name_upper).is_some() {
        // Rows to aggregate: the visible subset if one was supplied, otherwise all rows.
        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
            visible.clone()
        } else {
            (0..self.table.rows.len()).collect()
        };

        // Evaluate the first argument per row; nothing is collected when there are no args.
        let mut vals = Vec::new();
        for &row_idx in &rows_to_process {
            if !args.is_empty() {
                let value = self.evaluate(&args[0], row_idx)?;
                vals.push(value);
            }
        }

        // Keep the first occurrence of each value, comparing by `Debug` rendering.
        // NOTE(review): the legacy path below dedupes with `DataValue`'s own
        // `Hash`/`Eq` instead; confirm both notions of "distinct" agree.
        let mut seen = HashSet::new();
        let unique_values: Vec<_> = vals
            .into_iter()
            .filter(|v| {
                let key = format!("{:?}", v);
                seen.insert(key)
            })
            .collect();

        // Safe unwrap: presence was checked by the enclosing `is_some()`.
        let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
        let mut state = agg_func.create_state();

        for value in &unique_values {
            state.accumulate(value)?;
        }

        return Ok(state.finalize());
    }

    if self.aggregate_registry.get(&name_upper).is_some() {
        // Rows to aggregate: the visible subset if one was supplied, otherwise all rows.
        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
            visible.clone()
        } else {
            (0..self.table.rows.len()).collect()
        };

        if name_upper == "STRING_AGG" && args.len() >= 2 {
            // The separator is evaluated once, against row 0. String values
            // (plain or interned) are used; anything else falls back to ",".
            // The inner `args.len() >= 2` test is always true here.
            let mut state = crate::sql::aggregates::AggregateState::StringAgg(
                if args.len() >= 2 {
                    let separator = self.evaluate(&args[1], 0)?; match separator {
                        DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
                        DataValue::InternedString(s) => {
                            crate::sql::aggregates::StringAggState::new(&s)
                        }
                        _ => crate::sql::aggregates::StringAggState::new(","), }
                } else {
                    crate::sql::aggregates::StringAggState::new(",")
                },
            );

            // Values already seen (by `DataValue` equality) are skipped.
            let mut seen_values = HashSet::new();

            for &row_idx in &rows_to_process {
                let value = self.evaluate(&args[0], row_idx)?;

                if !seen_values.insert(value.clone()) {
                    continue; }

                // Re-fetched every row, presumably so the registry borrow does not
                // overlap the `&mut self` taken by `self.evaluate` above.
                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                agg_func.accumulate(&mut state, &value)?;
            }

            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
            return Ok(agg_func.finalize(state));
        }

        // Generic legacy aggregate: evaluate args[0] per row (when args exist)...
        let mut vals = Vec::new();
        for &row_idx in &rows_to_process {
            if !args.is_empty() {
                let value = self.evaluate(&args[0], row_idx)?;
                vals.push(value);
            }
        }

        // ...then keep the first occurrence of each value by `DataValue` equality.
        let mut seen = HashSet::new();
        let mut unique_values = Vec::new();
        for value in vals {
            if seen.insert(value.clone()) {
                unique_values.push(value);
            }
        }

        let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
        let mut state = agg_func.init();

        for value in &unique_values {
            agg_func.accumulate(&mut state, value)?;
        }

        return Ok(agg_func.finalize(state));
    }

    Err(anyhow!("Unknown aggregate function: {}", name))
}
715
/// Evaluates a (non-DISTINCT) function call.
///
/// Resolution order, on the upper-cased name for the aggregate registries:
/// 1. `new_aggregate_registry` — aggregates over all visible rows;
/// 2. `aggregate_registry` — aggregates over all visible rows;
/// 3. `function_registry` — a scalar function evaluated for `row_index`,
///    looked up with `name` as given (case handling is up to
///    `FunctionRegistry::get`, which is not visible here).
///
/// Aggregate results do not depend on `row_index`.
///
/// # Errors
/// Propagates argument-evaluation and accumulation errors, and returns
/// `Unknown function` when no registry knows `name`.
fn evaluate_function(
    &mut self,
    name: &str,
    args: &[SqlExpression],
    row_index: usize,
) -> Result<DataValue> {
    let name_upper = name.to_uppercase();

    if self.new_aggregate_registry.get(&name_upper).is_some() {
        // Rows to aggregate: the visible subset if one was supplied, otherwise all rows.
        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
            visible.clone()
        } else {
            (0..self.table.rows.len()).collect()
        };

        let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
        let mut state = agg_func.create_state();

        if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
            // COUNT(), COUNT(*) and COUNT('*') count rows by feeding a dummy 1 per row.
            if args.is_empty()
                || (args.len() == 1
                    && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
                || (args.len() == 1
                    && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
            {
                for _ in &rows_to_process {
                    state.accumulate(&DataValue::Integer(1))?;
                }
            } else {
                // COUNT(expr): every evaluated value goes to the state, which
                // presumably decides how NULLs are counted.
                for &row_idx in &rows_to_process {
                    let value = self.evaluate(&args[0], row_idx)?;
                    state.accumulate(&value)?;
                }
            }
        } else {
            // Other aggregates use only the first argument; with no arguments the
            // state is finalized without ever being fed.
            if !args.is_empty() {
                for &row_idx in &rows_to_process {
                    let value = self.evaluate(&args[0], row_idx)?;
                    state.accumulate(&value)?;
                }
            }
        }

        return Ok(state.finalize());
    }

    if self.aggregate_registry.get(&name_upper).is_some() {
        // Rows to aggregate: the visible subset if one was supplied, otherwise all rows.
        let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
            visible.clone()
        } else {
            (0..self.table.rows.len()).collect()
        };

        if name_upper == "STRING_AGG" && args.len() >= 2 {
            // The separator is evaluated once, against row 0. String values
            // (plain or interned) are used; anything else falls back to ",".
            // The inner `args.len() >= 2` test is always true here.
            let mut state = crate::sql::aggregates::AggregateState::StringAgg(
                if args.len() >= 2 {
                    let separator = self.evaluate(&args[1], 0)?; match separator {
                        DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
                        DataValue::InternedString(s) => {
                            crate::sql::aggregates::StringAggState::new(&s)
                        }
                        _ => crate::sql::aggregates::StringAggState::new(","), }
                } else {
                    crate::sql::aggregates::StringAggState::new(",")
                },
            );

            for &row_idx in &rows_to_process {
                let value = self.evaluate(&args[0], row_idx)?;
                // Re-fetched every row, presumably so the registry borrow does not
                // overlap the `&mut self` taken by `self.evaluate` above.
                let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
                agg_func.accumulate(&mut state, &value)?;
            }

            let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
            return Ok(agg_func.finalize(state));
        }

        // Evaluate args[0] per row unless the call is `f()` or `f(*)`; in those
        // cases `values` is None and a dummy 1 is accumulated per row instead.
        let values = if !args.is_empty()
            && !(args.len() == 1
                && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
        {
            let mut vals = Vec::new();
            for &row_idx in &rows_to_process {
                let value = self.evaluate(&args[0], row_idx)?;
                vals.push(value);
            }
            Some(vals)
        } else {
            None
        };

        let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
        let mut state = agg_func.init();

        if let Some(values) = values {
            for value in &values {
                agg_func.accumulate(&mut state, value)?;
            }
        } else {
            for _ in &rows_to_process {
                agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
            }
        }

        return Ok(agg_func.finalize(state));
    }

    if self.function_registry.get(name).is_some() {
        // Scalar function: evaluate every argument for the current row, in order.
        let mut evaluated_args = Vec::new();
        for arg in args {
            evaluated_args.push(self.evaluate(arg, row_index)?);
        }

        let func = self.function_registry.get(name).unwrap();
        return func.evaluate(&evaluated_args);
    }

    Err(anyhow!("Unknown function: {}", name))
}
862
863 pub fn get_or_create_window_context(
866 &mut self,
867 spec: &WindowSpec,
868 ) -> Result<Arc<WindowContext>> {
869 let overall_start = Instant::now();
870
871 let key = spec.compute_hash();
873
874 if let Some(context) = self.window_contexts.get(&key) {
875 info!(
876 "WindowContext cache hit for spec (lookup: {:.2}μs)",
877 overall_start.elapsed().as_micros()
878 );
879 return Ok(Arc::clone(context));
880 }
881
882 info!("WindowContext cache miss - creating new context");
883 let dataview_start = Instant::now();
884
885 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
887 let view = DataView::new(Arc::new(self.table.clone()));
889 view
892 } else {
893 DataView::new(Arc::new(self.table.clone()))
894 };
895
896 info!(
897 "DataView creation took {:.2}μs",
898 dataview_start.elapsed().as_micros()
899 );
900 let context_start = Instant::now();
901
902 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
904
905 info!(
906 "WindowContext::new_with_spec took {:.2}ms (rows: {})",
907 context_start.elapsed().as_secs_f64() * 1000.0,
908 self.table.row_count()
909 );
910
911 let context = Arc::new(context);
912 self.window_contexts.insert(key, Arc::clone(&context));
913
914 info!(
915 "Total WindowContext creation (cache miss) took {:.2}ms",
916 overall_start.elapsed().as_secs_f64() * 1000.0
917 );
918
919 Ok(context)
920 }
921
922 fn evaluate_window_function(
924 &mut self,
925 name: &str,
926 args: &[SqlExpression],
927 spec: &WindowSpec,
928 row_index: usize,
929 ) -> Result<DataValue> {
930 let func_start = Instant::now();
931 let name_upper = name.to_uppercase();
932
933 debug!("Looking for window function {} in registry", name_upper);
935 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
936 debug!("Found window function {} in registry", name_upper);
937
938 let window_fn = window_fn_arc.as_ref();
940
941 window_fn.validate_args(args)?;
943
944 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
946
947 let context = self.get_or_create_window_context(&transformed_spec)?;
949
950 struct EvaluatorAdapter<'a, 'b> {
952 evaluator: &'a mut ArithmeticEvaluator<'b>,
953 row_index: usize,
954 }
955
956 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
957 fn evaluate(
958 &mut self,
959 expr: &SqlExpression,
960 row_index: usize,
961 ) -> Result<DataValue> {
962 self.evaluator.evaluate(expr, row_index)
963 }
964 }
965
966 let mut adapter = EvaluatorAdapter {
967 evaluator: self,
968 row_index,
969 };
970
971 let compute_start = Instant::now();
972 let result = window_fn.compute(&context, row_index, args, &mut adapter);
974
975 info!(
976 "{} (registry) evaluation: total={:.2}μs, compute={:.2}μs",
977 name_upper,
978 func_start.elapsed().as_micros(),
979 compute_start.elapsed().as_micros()
980 );
981
982 return result;
983 }
984
985 let context_start = Instant::now();
987 let context = self.get_or_create_window_context(spec)?;
988 let context_time = context_start.elapsed();
989
990 let eval_start = Instant::now();
991
992 let result = match name_upper.as_str() {
993 "LAG" => {
994 if args.is_empty() {
996 return Err(anyhow!("LAG requires at least 1 argument"));
997 }
998
999 let column = match &args[0] {
1001 SqlExpression::Column(col) => col.clone(),
1002 _ => return Err(anyhow!("LAG first argument must be a column")),
1003 };
1004
1005 let offset = if args.len() > 1 {
1007 match self.evaluate(&args[1], row_index)? {
1008 DataValue::Integer(i) => i as i32,
1009 _ => return Err(anyhow!("LAG offset must be an integer")),
1010 }
1011 } else {
1012 1
1013 };
1014
1015 let offset_start = Instant::now();
1016 let value = context
1018 .get_offset_value(row_index, -offset, &column.name)
1019 .unwrap_or(DataValue::Null);
1020
1021 debug!(
1022 "LAG offset access took {:.2}μs (offset={})",
1023 offset_start.elapsed().as_micros(),
1024 offset
1025 );
1026
1027 Ok(value)
1028 }
1029 "LEAD" => {
1030 if args.is_empty() {
1032 return Err(anyhow!("LEAD requires at least 1 argument"));
1033 }
1034
1035 let column = match &args[0] {
1037 SqlExpression::Column(col) => col.clone(),
1038 _ => return Err(anyhow!("LEAD first argument must be a column")),
1039 };
1040
1041 let offset = if args.len() > 1 {
1043 match self.evaluate(&args[1], row_index)? {
1044 DataValue::Integer(i) => i as i32,
1045 _ => return Err(anyhow!("LEAD offset must be an integer")),
1046 }
1047 } else {
1048 1
1049 };
1050
1051 let offset_start = Instant::now();
1052 let value = context
1054 .get_offset_value(row_index, offset, &column.name)
1055 .unwrap_or(DataValue::Null);
1056
1057 debug!(
1058 "LEAD offset access took {:.2}μs (offset={})",
1059 offset_start.elapsed().as_micros(),
1060 offset
1061 );
1062
1063 Ok(value)
1064 }
1065 "ROW_NUMBER" => {
1066 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
1068 }
1069 "RANK" => {
1070 Ok(DataValue::Integer(context.get_rank(row_index)))
1072 }
1073 "DENSE_RANK" => {
1074 Ok(DataValue::Integer(context.get_dense_rank(row_index)))
1076 }
1077 "FIRST_VALUE" => {
1078 if args.is_empty() {
1080 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
1081 }
1082
1083 let column = match &args[0] {
1084 SqlExpression::Column(col) => col.clone(),
1085 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
1086 };
1087
1088 if context.has_frame() {
1090 Ok(context
1091 .get_frame_first_value(row_index, &column.name)
1092 .unwrap_or(DataValue::Null))
1093 } else {
1094 Ok(context
1095 .get_first_value(row_index, &column.name)
1096 .unwrap_or(DataValue::Null))
1097 }
1098 }
1099 "LAST_VALUE" => {
1100 if args.is_empty() {
1102 return Err(anyhow!("LAST_VALUE requires 1 argument"));
1103 }
1104
1105 let column = match &args[0] {
1106 SqlExpression::Column(col) => col.clone(),
1107 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
1108 };
1109
1110 if context.has_frame() {
1112 Ok(context
1113 .get_frame_last_value(row_index, &column.name)
1114 .unwrap_or(DataValue::Null))
1115 } else {
1116 Ok(context
1117 .get_last_value(row_index, &column.name)
1118 .unwrap_or(DataValue::Null))
1119 }
1120 }
1121 "SUM" => {
1122 if args.is_empty() {
1124 return Err(anyhow!("SUM requires 1 argument"));
1125 }
1126
1127 let column = match &args[0] {
1128 SqlExpression::Column(col) => col.clone(),
1129 _ => return Err(anyhow!("SUM argument must be a column")),
1130 };
1131
1132 if context.has_frame() {
1134 Ok(context
1135 .get_frame_sum(row_index, &column.name)
1136 .unwrap_or(DataValue::Null))
1137 } else {
1138 Ok(context
1139 .get_partition_sum(row_index, &column.name)
1140 .unwrap_or(DataValue::Null))
1141 }
1142 }
1143 "AVG" => {
1144 if args.is_empty() {
1146 return Err(anyhow!("AVG requires 1 argument"));
1147 }
1148
1149 let column = match &args[0] {
1150 SqlExpression::Column(col) => col.clone(),
1151 _ => return Err(anyhow!("AVG argument must be a column")),
1152 };
1153
1154 if context.has_frame() {
1156 Ok(context
1157 .get_frame_avg(row_index, &column.name)
1158 .unwrap_or(DataValue::Null))
1159 } else {
1160 Ok(context
1161 .get_partition_avg(row_index, &column.name)
1162 .unwrap_or(DataValue::Null))
1163 }
1164 }
1165 "STDDEV" | "STDEV" => {
1166 if args.is_empty() {
1168 return Err(anyhow!("STDDEV requires 1 argument"));
1169 }
1170
1171 let column = match &args[0] {
1172 SqlExpression::Column(col) => col.clone(),
1173 _ => return Err(anyhow!("STDDEV argument must be a column")),
1174 };
1175
1176 Ok(context
1177 .get_frame_stddev(row_index, &column.name)
1178 .unwrap_or(DataValue::Null))
1179 }
1180 "VARIANCE" | "VAR" => {
1181 if args.is_empty() {
1183 return Err(anyhow!("VARIANCE requires 1 argument"));
1184 }
1185
1186 let column = match &args[0] {
1187 SqlExpression::Column(col) => col.clone(),
1188 _ => return Err(anyhow!("VARIANCE argument must be a column")),
1189 };
1190
1191 Ok(context
1192 .get_frame_variance(row_index, &column.name)
1193 .unwrap_or(DataValue::Null))
1194 }
1195 "MIN" => {
1196 if args.is_empty() {
1198 return Err(anyhow!("MIN requires 1 argument"));
1199 }
1200
1201 let column = match &args[0] {
1202 SqlExpression::Column(col) => col.clone(),
1203 _ => return Err(anyhow!("MIN argument must be a column")),
1204 };
1205
1206 let frame_rows = context.get_frame_rows(row_index);
1207 if frame_rows.is_empty() {
1208 return Ok(DataValue::Null);
1209 }
1210
1211 let source_table = context.source();
1212 let col_idx = source_table
1213 .get_column_index(&column.name)
1214 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1215
1216 let mut min_value: Option<DataValue> = None;
1217 for &row_idx in &frame_rows {
1218 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1219 if !matches!(value, DataValue::Null) {
1220 match &min_value {
1221 None => min_value = Some(value.clone()),
1222 Some(current_min) => {
1223 if value < current_min {
1224 min_value = Some(value.clone());
1225 }
1226 }
1227 }
1228 }
1229 }
1230 }
1231
1232 Ok(min_value.unwrap_or(DataValue::Null))
1233 }
1234 "MAX" => {
1235 if args.is_empty() {
1237 return Err(anyhow!("MAX requires 1 argument"));
1238 }
1239
1240 let column = match &args[0] {
1241 SqlExpression::Column(col) => col.clone(),
1242 _ => return Err(anyhow!("MAX argument must be a column")),
1243 };
1244
1245 let frame_rows = context.get_frame_rows(row_index);
1246 if frame_rows.is_empty() {
1247 return Ok(DataValue::Null);
1248 }
1249
1250 let source_table = context.source();
1251 let col_idx = source_table
1252 .get_column_index(&column.name)
1253 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1254
1255 let mut max_value: Option<DataValue> = None;
1256 for &row_idx in &frame_rows {
1257 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1258 if !matches!(value, DataValue::Null) {
1259 match &max_value {
1260 None => max_value = Some(value.clone()),
1261 Some(current_max) => {
1262 if value > current_max {
1263 max_value = Some(value.clone());
1264 }
1265 }
1266 }
1267 }
1268 }
1269 }
1270
1271 Ok(max_value.unwrap_or(DataValue::Null))
1272 }
1273 "COUNT" => {
1274 if args.is_empty() {
1278 if context.has_frame() {
1280 Ok(context
1281 .get_frame_count(row_index, None)
1282 .unwrap_or(DataValue::Null))
1283 } else {
1284 Ok(context
1285 .get_partition_count(row_index, None)
1286 .unwrap_or(DataValue::Null))
1287 }
1288 } else {
1289 let column = match &args[0] {
1291 SqlExpression::Column(col) => {
1292 if col.name == "*" {
1293 if context.has_frame() {
1295 return Ok(context
1296 .get_frame_count(row_index, None)
1297 .unwrap_or(DataValue::Null));
1298 } else {
1299 return Ok(context
1300 .get_partition_count(row_index, None)
1301 .unwrap_or(DataValue::Null));
1302 }
1303 }
1304 col.clone()
1305 }
1306 SqlExpression::StringLiteral(s) if s == "*" => {
1307 if context.has_frame() {
1309 return Ok(context
1310 .get_frame_count(row_index, None)
1311 .unwrap_or(DataValue::Null));
1312 } else {
1313 return Ok(context
1314 .get_partition_count(row_index, None)
1315 .unwrap_or(DataValue::Null));
1316 }
1317 }
1318 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1319 };
1320
1321 if context.has_frame() {
1323 Ok(context
1324 .get_frame_count(row_index, Some(&column.name))
1325 .unwrap_or(DataValue::Null))
1326 } else {
1327 Ok(context
1328 .get_partition_count(row_index, Some(&column.name))
1329 .unwrap_or(DataValue::Null))
1330 }
1331 }
1332 }
1333 _ => Err(anyhow!("Unknown window function: {}", name)),
1334 };
1335
1336 let eval_time = eval_start.elapsed();
1337
1338 info!(
1339 "{} (built-in) evaluation: total={:.2}μs, context={:.2}μs, eval={:.2}μs",
1340 name_upper,
1341 func_start.elapsed().as_micros(),
1342 context_time.as_micros(),
1343 eval_time.as_micros()
1344 );
1345
1346 result
1347 }
1348
1349 fn evaluate_method_call(
1351 &mut self,
1352 object: &str,
1353 method: &str,
1354 args: &[SqlExpression],
1355 row_index: usize,
1356 ) -> Result<DataValue> {
1357 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1359 let suggestion = self.find_similar_column(object);
1360 match suggestion {
1361 Some(similar) => {
1362 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1363 }
1364 None => anyhow!("Column '{}' not found", object),
1365 }
1366 })?;
1367
1368 let cell_value = self.table.get_value(row_index, col_index).cloned();
1369
1370 self.evaluate_method_on_value(
1371 &cell_value.unwrap_or(DataValue::Null),
1372 method,
1373 args,
1374 row_index,
1375 )
1376 }
1377
1378 fn evaluate_method_on_value(
1380 &mut self,
1381 value: &DataValue,
1382 method: &str,
1383 args: &[SqlExpression],
1384 row_index: usize,
1385 ) -> Result<DataValue> {
1386 let function_name = match method.to_lowercase().as_str() {
1391 "trim" => "TRIM",
1392 "trimstart" | "trimbegin" => "TRIMSTART",
1393 "trimend" => "TRIMEND",
1394 "length" | "len" => "LENGTH",
1395 "contains" => "CONTAINS",
1396 "startswith" => "STARTSWITH",
1397 "endswith" => "ENDSWITH",
1398 "indexof" => "INDEXOF",
1399 _ => method, };
1401
1402 if self.function_registry.get(function_name).is_some() {
1404 debug!(
1405 "Proxying method '{}' through function registry as '{}'",
1406 method, function_name
1407 );
1408
1409 let mut func_args = vec![value.clone()];
1411
1412 for arg in args {
1414 func_args.push(self.evaluate(arg, row_index)?);
1415 }
1416
1417 let func = self.function_registry.get(function_name).unwrap();
1419 return func.evaluate(&func_args);
1420 }
1421
1422 Err(anyhow!(
1425 "Method '{}' not found. It should be registered in the function registry.",
1426 method
1427 ))
1428 }
1429
1430 fn evaluate_case_expression(
1432 &mut self,
1433 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1434 else_branch: &Option<Box<SqlExpression>>,
1435 row_index: usize,
1436 ) -> Result<DataValue> {
1437 debug!(
1438 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1439 row_index
1440 );
1441
1442 for branch in when_branches {
1444 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1446
1447 if condition_result {
1448 debug!("CASE: WHEN condition matched, evaluating result expression");
1449 return self.evaluate(&branch.result, row_index);
1450 }
1451 }
1452
1453 if let Some(else_expr) = else_branch {
1455 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1456 self.evaluate(else_expr, row_index)
1457 } else {
1458 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1459 Ok(DataValue::Null)
1460 }
1461 }
1462
1463 fn evaluate_simple_case_expression(
1465 &mut self,
1466 expr: &Box<SqlExpression>,
1467 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1468 else_branch: &Option<Box<SqlExpression>>,
1469 row_index: usize,
1470 ) -> Result<DataValue> {
1471 debug!(
1472 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1473 row_index
1474 );
1475
1476 let case_value = self.evaluate(expr, row_index)?;
1478 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1479
1480 for branch in when_branches {
1482 let when_value = self.evaluate(&branch.value, row_index)?;
1484
1485 if self.values_equal(&case_value, &when_value)? {
1487 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1488 return self.evaluate(&branch.result, row_index);
1489 }
1490 }
1491
1492 if let Some(else_expr) = else_branch {
1494 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1495 self.evaluate(else_expr, row_index)
1496 } else {
1497 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1498 Ok(DataValue::Null)
1499 }
1500 }
1501
1502 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1504 match (left, right) {
1505 (DataValue::Null, DataValue::Null) => Ok(true),
1506 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1507 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1508 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1509 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1510 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1511 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1512 (DataValue::Integer(a), DataValue::Float(b)) => {
1514 Ok((*a as f64 - b).abs() < f64::EPSILON)
1515 }
1516 (DataValue::Float(a), DataValue::Integer(b)) => {
1517 Ok((a - *b as f64).abs() < f64::EPSILON)
1518 }
1519 _ => Ok(false),
1520 }
1521 }
1522
1523 fn evaluate_condition_as_bool(
1525 &mut self,
1526 expr: &SqlExpression,
1527 row_index: usize,
1528 ) -> Result<bool> {
1529 let value = self.evaluate(expr, row_index)?;
1530
1531 match value {
1532 DataValue::Boolean(b) => Ok(b),
1533 DataValue::Integer(i) => Ok(i != 0),
1534 DataValue::Float(f) => Ok(f != 0.0),
1535 DataValue::Null => Ok(false),
1536 DataValue::String(s) => Ok(!s.is_empty()),
1537 DataValue::InternedString(s) => Ok(!s.is_empty()),
1538 _ => Ok(true), }
1540 }
1541
1542 fn evaluate_datetime_constructor(
1544 &self,
1545 year: i32,
1546 month: u32,
1547 day: u32,
1548 hour: Option<u32>,
1549 minute: Option<u32>,
1550 second: Option<u32>,
1551 ) -> Result<DataValue> {
1552 use chrono::{NaiveDate, TimeZone, Utc};
1553
1554 let date = NaiveDate::from_ymd_opt(year, month, day)
1556 .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1557
1558 let hour = hour.unwrap_or(0);
1560 let minute = minute.unwrap_or(0);
1561 let second = second.unwrap_or(0);
1562
1563 let naive_datetime = date
1564 .and_hms_opt(hour, minute, second)
1565 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1566
1567 let datetime = Utc.from_utc_datetime(&naive_datetime);
1569
1570 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1572 Ok(DataValue::String(datetime_str))
1573 }
1574
1575 fn evaluate_datetime_today(
1577 &self,
1578 hour: Option<u32>,
1579 minute: Option<u32>,
1580 second: Option<u32>,
1581 ) -> Result<DataValue> {
1582 use chrono::{TimeZone, Utc};
1583
1584 let today = Utc::now().date_naive();
1586
1587 let hour = hour.unwrap_or(0);
1589 let minute = minute.unwrap_or(0);
1590 let second = second.unwrap_or(0);
1591
1592 let naive_datetime = today
1593 .and_hms_opt(hour, minute, second)
1594 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1595
1596 let datetime = Utc.from_utc_datetime(&naive_datetime);
1598
1599 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1601 Ok(DataValue::String(datetime_str))
1602 }
1603}
1604
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table: `a = 10` (integer), `b = 2.5` (float),
    /// `c = 4` (integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();

        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));

        // Avoid literals close to well-known constants (e.g. 3.14 ~ PI): clippy's
        // deny-by-default `approx_constant` lint rejects them.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));

        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));

        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };

        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal_numeric_and_null() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // Integer/float comparisons cross types.
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        // NULL only equals NULL.
        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());
        // Numbers never equal strings, even with the same spelling.
        assert!(!evaluator
            .values_equal(&DataValue::Integer(1), &DataValue::String("1".to_string()))
            .unwrap());
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let zero = SqlExpression::NumberLiteral("0".to_string());
        let column_a = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));

        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());
        assert!(evaluator.evaluate_condition_as_bool(&column_a, 0).unwrap());
    }
}