1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
/// Evaluates SQL expressions (literals, column references, binary operators,
/// function calls, aggregates, window functions, method calls and CASE
/// expressions) against the rows of a borrowed [`DataTable`].
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows and columns are read during evaluation.
    table: &'a DataTable,
    /// Date notation captured at construction time. It is not read anywhere in
    /// the visible part of this file (hence the leading underscore).
    _date_notation: String,
    /// Functions looked up by name for non-aggregate function calls and for
    /// method-call syntax.
    function_registry: Arc<FunctionRegistry>,
    /// Aggregate registry consulted after `new_aggregate_registry`.
    aggregate_registry: Arc<AggregateRegistry>,
    /// Aggregate registry consulted first when a function call is evaluated.
    new_aggregate_registry: Arc<AggregateFunctionRegistry>,
    /// Window functions looked up by upper-cased name before the built-in
    /// window function fallbacks are tried.
    window_function_registry: Arc<WindowFunctionRegistry>,
    /// When `Some`, aggregates iterate over these row indices instead of over
    /// every row of `table`.
    visible_rows: Option<Vec<usize>>,
    /// Cache of window contexts, keyed by the `Debug` rendering of the window
    /// specification they were built for.
    window_contexts: HashMap<String, Arc<WindowContext>>,
    /// Alias map supplied through `with_table_aliases`.
    table_aliases: HashMap<String, String>,
}
30
31impl<'a> ArithmeticEvaluator<'a> {
32 #[must_use]
33 pub fn new(table: &'a DataTable) -> Self {
34 Self {
35 table,
36 _date_notation: get_date_notation(),
37 function_registry: Arc::new(FunctionRegistry::new()),
38 aggregate_registry: Arc::new(AggregateRegistry::new()),
39 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41 visible_rows: None,
42 window_contexts: HashMap::new(),
43 table_aliases: HashMap::new(),
44 }
45 }
46
47 #[must_use]
48 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49 Self {
50 table,
51 _date_notation: date_notation,
52 function_registry: Arc::new(FunctionRegistry::new()),
53 aggregate_registry: Arc::new(AggregateRegistry::new()),
54 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56 visible_rows: None,
57 window_contexts: HashMap::new(),
58 table_aliases: HashMap::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65 self.visible_rows = Some(rows);
66 self
67 }
68
69 #[must_use]
71 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72 self.table_aliases = aliases;
73 self
74 }
75
76 #[must_use]
77 pub fn with_date_notation_and_registry(
78 table: &'a DataTable,
79 date_notation: String,
80 function_registry: Arc<FunctionRegistry>,
81 ) -> Self {
82 Self {
83 table,
84 _date_notation: date_notation,
85 function_registry,
86 aggregate_registry: Arc::new(AggregateRegistry::new()),
87 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89 visible_rows: None,
90 window_contexts: HashMap::new(),
91 table_aliases: HashMap::new(),
92 }
93 }
94
95 fn find_similar_column(&self, name: &str) -> Option<String> {
97 let columns = self.table.column_names();
98 let mut best_match: Option<(String, usize)> = None;
99
100 for col in columns {
101 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102 let max_distance = if name.len() > 10 { 3 } else { 2 };
105 if distance <= max_distance {
106 match &best_match {
107 None => best_match = Some((col, distance)),
108 Some((_, best_dist)) if distance < *best_dist => {
109 best_match = Some((col, distance));
110 }
111 _ => {}
112 }
113 }
114 }
115
116 best_match.map(|(name, _)| name)
117 }
118
119 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123 }
124
125 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127 debug!(
128 "ArithmeticEvaluator: evaluating {:?} for row {}",
129 expr, row_index
130 );
131
132 match expr {
133 SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137 SqlExpression::Null => Ok(DataValue::Null),
138 SqlExpression::BinaryOp { left, op, right } => {
139 self.evaluate_binary_op(left, op, right, row_index)
140 }
141 SqlExpression::FunctionCall {
142 name,
143 args,
144 distinct,
145 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146 SqlExpression::WindowFunction {
147 name,
148 args,
149 window_spec,
150 } => self.evaluate_window_function(name, args, window_spec, row_index),
151 SqlExpression::MethodCall {
152 object,
153 method,
154 args,
155 } => self.evaluate_method_call(object, method, args, row_index),
156 SqlExpression::ChainedMethodCall { base, method, args } => {
157 let base_value = self.evaluate(base, row_index)?;
159 self.evaluate_method_on_value(&base_value, method, args, row_index)
160 }
161 SqlExpression::CaseExpression {
162 when_branches,
163 else_branch,
164 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165 SqlExpression::SimpleCaseExpression {
166 expr,
167 when_branches,
168 else_branch,
169 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170 SqlExpression::DateTimeConstructor {
171 year,
172 month,
173 day,
174 hour,
175 minute,
176 second,
177 } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
178 SqlExpression::DateTimeToday {
179 hour,
180 minute,
181 second,
182 } => self.evaluate_datetime_today(*hour, *minute, *second),
183 _ => Err(anyhow!(
184 "Unsupported expression type for arithmetic evaluation: {:?}",
185 expr
186 )),
187 }
188 }
189
190 fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
192 if let Some(table_prefix) = &column_ref.table_prefix {
193 let qualified_name = format!("{}.{}", table_prefix, column_ref.name);
195
196 if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
197 return self
198 .table
199 .get_value(row_index, col_idx)
200 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
201 .map(|v| v.clone());
202 }
203
204 Err(anyhow!("Column '{}' not found", qualified_name))
206 } else {
207 self.evaluate_column(&column_ref.name, row_index)
209 }
210 }
211
212 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
214 let resolved_column = if column_name.contains('.') {
216 if let Some(dot_pos) = column_name.rfind('.') {
218 let _table_or_alias = &column_name[..dot_pos];
219 let col_name = &column_name[dot_pos + 1..];
220
221 debug!(
224 "Resolving qualified column: {} -> {}",
225 column_name, col_name
226 );
227 col_name.to_string()
228 } else {
229 column_name.to_string()
230 }
231 } else {
232 column_name.to_string()
233 };
234
235 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
236 idx
237 } else if resolved_column != column_name {
238 if let Some(idx) = self.table.get_column_index(column_name) {
240 idx
241 } else {
242 let suggestion = self.find_similar_column(&resolved_column);
243 return Err(match suggestion {
244 Some(similar) => anyhow!(
245 "Column '{}' not found. Did you mean '{}'?",
246 column_name,
247 similar
248 ),
249 None => anyhow!("Column '{}' not found", column_name),
250 });
251 }
252 } else {
253 let suggestion = self.find_similar_column(&resolved_column);
254 return Err(match suggestion {
255 Some(similar) => anyhow!(
256 "Column '{}' not found. Did you mean '{}'?",
257 column_name,
258 similar
259 ),
260 None => anyhow!("Column '{}' not found", column_name),
261 });
262 };
263
264 if row_index >= self.table.row_count() {
265 return Err(anyhow!("Row index {} out of bounds", row_index));
266 }
267
268 let row = self
269 .table
270 .get_row(row_index)
271 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
272
273 let value = row
274 .get(col_index)
275 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
276
277 Ok(value.clone())
278 }
279
280 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
282 if let Ok(int_val) = number_str.parse::<i64>() {
284 return Ok(DataValue::Integer(int_val));
285 }
286
287 if let Ok(float_val) = number_str.parse::<f64>() {
289 return Ok(DataValue::Float(float_val));
290 }
291
292 Err(anyhow!("Invalid number literal: {}", number_str))
293 }
294
295 fn evaluate_binary_op(
297 &mut self,
298 left: &SqlExpression,
299 op: &str,
300 right: &SqlExpression,
301 row_index: usize,
302 ) -> Result<DataValue> {
303 let left_val = self.evaluate(left, row_index)?;
304 let right_val = self.evaluate(right, row_index)?;
305
306 debug!(
307 "ArithmeticEvaluator: {} {} {}",
308 self.format_value(&left_val),
309 op,
310 self.format_value(&right_val)
311 );
312
313 match op {
314 "+" => self.add_values(&left_val, &right_val),
315 "-" => self.subtract_values(&left_val, &right_val),
316 "*" => self.multiply_values(&left_val, &right_val),
317 "/" => self.divide_values(&left_val, &right_val),
318 "%" => {
319 let args = vec![left.clone(), right.clone()];
321 self.evaluate_function("MOD", &args, row_index)
322 }
323 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
326 let result = compare_with_op(&left_val, &right_val, op, false);
327 Ok(DataValue::Boolean(result))
328 }
329 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
331 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
332 "AND" => {
334 let left_bool = self.to_bool(&left_val)?;
335 let right_bool = self.to_bool(&right_val)?;
336 Ok(DataValue::Boolean(left_bool && right_bool))
337 }
338 "OR" => {
339 let left_bool = self.to_bool(&left_val)?;
340 let right_bool = self.to_bool(&right_val)?;
341 Ok(DataValue::Boolean(left_bool || right_bool))
342 }
343 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
344 }
345 }
346
347 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
349 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
351 return Ok(DataValue::Null);
352 }
353
354 match (left, right) {
355 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
356 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
357 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
358 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
359 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
360 }
361 }
362
363 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
365 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
367 return Ok(DataValue::Null);
368 }
369
370 match (left, right) {
371 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
372 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
373 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
374 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
375 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
376 }
377 }
378
379 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
381 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
383 return Ok(DataValue::Null);
384 }
385
386 match (left, right) {
387 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
388 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
389 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
390 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
391 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
392 }
393 }
394
395 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
397 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
399 return Ok(DataValue::Null);
400 }
401
402 let is_zero = match right {
404 DataValue::Integer(0) => true,
405 DataValue::Float(f) if *f == 0.0 => true, _ => false,
407 };
408
409 if is_zero {
410 return Err(anyhow!("Division by zero"));
411 }
412
413 match (left, right) {
414 (DataValue::Integer(a), DataValue::Integer(b)) => {
415 if a % b == 0 {
417 Ok(DataValue::Integer(a / b))
418 } else {
419 Ok(DataValue::Float(*a as f64 / *b as f64))
420 }
421 }
422 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
423 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
424 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
425 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
426 }
427 }
428
429 fn format_value(&self, value: &DataValue) -> String {
431 match value {
432 DataValue::Integer(i) => i.to_string(),
433 DataValue::Float(f) => f.to_string(),
434 DataValue::String(s) => format!("'{s}'"),
435 _ => format!("{value:?}"),
436 }
437 }
438
439 fn to_bool(&self, value: &DataValue) -> Result<bool> {
441 match value {
442 DataValue::Boolean(b) => Ok(*b),
443 DataValue::Integer(i) => Ok(*i != 0),
444 DataValue::Float(f) => Ok(*f != 0.0),
445 DataValue::Null => Ok(false),
446 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
447 }
448 }
449
450 fn evaluate_function_with_distinct(
452 &mut self,
453 name: &str,
454 args: &[SqlExpression],
455 distinct: bool,
456 row_index: usize,
457 ) -> Result<DataValue> {
458 if distinct {
460 let name_upper = name.to_uppercase();
461
462 if self.aggregate_registry.is_aggregate(&name_upper)
464 || self.new_aggregate_registry.contains(&name_upper)
465 {
466 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
467 } else {
468 return Err(anyhow!(
469 "DISTINCT can only be used with aggregate functions"
470 ));
471 }
472 }
473
474 self.evaluate_function(name, args, row_index)
476 }
477
478 fn evaluate_aggregate_with_distinct(
479 &mut self,
480 name: &str,
481 args: &[SqlExpression],
482 _row_index: usize,
483 ) -> Result<DataValue> {
484 let name_upper = name.to_uppercase();
485
486 if self.new_aggregate_registry.get(&name_upper).is_some() {
488 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
489 visible.clone()
490 } else {
491 (0..self.table.rows.len()).collect()
492 };
493
494 let mut vals = Vec::new();
496 for &row_idx in &rows_to_process {
497 if !args.is_empty() {
498 let value = self.evaluate(&args[0], row_idx)?;
499 vals.push(value);
500 }
501 }
502
503 let mut seen = HashSet::new();
505 let unique_values: Vec<_> = vals
506 .into_iter()
507 .filter(|v| {
508 let key = format!("{:?}", v);
509 seen.insert(key)
510 })
511 .collect();
512
513 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
515 let mut state = agg_func.create_state();
516
517 for value in &unique_values {
519 state.accumulate(value)?;
520 }
521
522 return Ok(state.finalize());
523 }
524
525 if self.aggregate_registry.get(&name_upper).is_some() {
527 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
529 visible.clone()
530 } else {
531 (0..self.table.rows.len()).collect()
532 };
533
534 if name_upper == "STRING_AGG" && args.len() >= 2 {
536 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
538 if args.len() >= 2 {
540 let separator = self.evaluate(&args[1], 0)?; match separator {
542 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
543 DataValue::InternedString(s) => {
544 crate::sql::aggregates::StringAggState::new(&s)
545 }
546 _ => crate::sql::aggregates::StringAggState::new(","), }
548 } else {
549 crate::sql::aggregates::StringAggState::new(",")
550 },
551 );
552
553 let mut seen_values = HashSet::new();
556
557 for &row_idx in &rows_to_process {
558 let value = self.evaluate(&args[0], row_idx)?;
559
560 if !seen_values.insert(value.clone()) {
562 continue; }
564
565 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
567 agg_func.accumulate(&mut state, &value)?;
568 }
569
570 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
572 return Ok(agg_func.finalize(state));
573 }
574
575 let mut vals = Vec::new();
578 for &row_idx in &rows_to_process {
579 if !args.is_empty() {
580 let value = self.evaluate(&args[0], row_idx)?;
581 vals.push(value);
582 }
583 }
584
585 let mut seen = HashSet::new();
587 let mut unique_values = Vec::new();
588 for value in vals {
589 if seen.insert(value.clone()) {
590 unique_values.push(value);
591 }
592 }
593
594 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
596 let mut state = agg_func.init();
597
598 for value in &unique_values {
600 agg_func.accumulate(&mut state, value)?;
601 }
602
603 return Ok(agg_func.finalize(state));
604 }
605
606 Err(anyhow!("Unknown aggregate function: {}", name))
607 }
608
609 fn evaluate_function(
610 &mut self,
611 name: &str,
612 args: &[SqlExpression],
613 row_index: usize,
614 ) -> Result<DataValue> {
615 let name_upper = name.to_uppercase();
617
618 if self.new_aggregate_registry.get(&name_upper).is_some() {
620 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
622 visible.clone()
623 } else {
624 (0..self.table.rows.len()).collect()
625 };
626
627 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
629 let mut state = agg_func.create_state();
630
631 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
633 if args.is_empty()
634 || (args.len() == 1
635 && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
636 || (args.len() == 1
637 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
638 {
639 for _ in &rows_to_process {
641 state.accumulate(&DataValue::Integer(1))?;
642 }
643 } else {
644 for &row_idx in &rows_to_process {
646 let value = self.evaluate(&args[0], row_idx)?;
647 state.accumulate(&value)?;
648 }
649 }
650 } else {
651 if !args.is_empty() {
653 for &row_idx in &rows_to_process {
654 let value = self.evaluate(&args[0], row_idx)?;
655 state.accumulate(&value)?;
656 }
657 }
658 }
659
660 return Ok(state.finalize());
661 }
662
663 if self.aggregate_registry.get(&name_upper).is_some() {
665 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
667 visible.clone()
668 } else {
669 (0..self.table.rows.len()).collect()
670 };
671
672 if name_upper == "STRING_AGG" && args.len() >= 2 {
674 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
676 if args.len() >= 2 {
678 let separator = self.evaluate(&args[1], 0)?; match separator {
680 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
681 DataValue::InternedString(s) => {
682 crate::sql::aggregates::StringAggState::new(&s)
683 }
684 _ => crate::sql::aggregates::StringAggState::new(","), }
686 } else {
687 crate::sql::aggregates::StringAggState::new(",")
688 },
689 );
690
691 for &row_idx in &rows_to_process {
693 let value = self.evaluate(&args[0], row_idx)?;
694 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
696 agg_func.accumulate(&mut state, &value)?;
697 }
698
699 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
701 return Ok(agg_func.finalize(state));
702 }
703
704 let values = if !args.is_empty()
706 && !(args.len() == 1
707 && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
708 {
709 let mut vals = Vec::new();
711 for &row_idx in &rows_to_process {
712 let value = self.evaluate(&args[0], row_idx)?;
713 vals.push(value);
714 }
715 Some(vals)
716 } else {
717 None
718 };
719
720 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
722 let mut state = agg_func.init();
723
724 if let Some(values) = values {
725 for value in &values {
727 agg_func.accumulate(&mut state, value)?;
728 }
729 } else {
730 for _ in &rows_to_process {
732 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
733 }
734 }
735
736 return Ok(agg_func.finalize(state));
737 }
738
739 if self.function_registry.get(name).is_some() {
741 let mut evaluated_args = Vec::new();
743 for arg in args {
744 evaluated_args.push(self.evaluate(arg, row_index)?);
745 }
746
747 let func = self.function_registry.get(name).unwrap();
749 return func.evaluate(&evaluated_args);
750 }
751
752 Err(anyhow!("Unknown function: {}", name))
754 }
755
756 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
758 let key = format!("{:?}", spec);
760
761 if let Some(context) = self.window_contexts.get(&key) {
762 return Ok(Arc::clone(context));
763 }
764
765 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
767 let view = DataView::new(Arc::new(self.table.clone()));
769 view
772 } else {
773 DataView::new(Arc::new(self.table.clone()))
774 };
775
776 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
778
779 let context = Arc::new(context);
780 self.window_contexts.insert(key, Arc::clone(&context));
781 Ok(context)
782 }
783
784 fn evaluate_window_function(
786 &mut self,
787 name: &str,
788 args: &[SqlExpression],
789 spec: &WindowSpec,
790 row_index: usize,
791 ) -> Result<DataValue> {
792 let name_upper = name.to_uppercase();
793
794 debug!("Looking for window function {} in registry", name_upper);
796 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
797 debug!("Found window function {} in registry", name_upper);
798
799 let window_fn = window_fn_arc.as_ref();
801
802 window_fn.validate_args(args)?;
804
805 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
807
808 let context = self.get_or_create_window_context(&transformed_spec)?;
810
811 struct EvaluatorAdapter<'a, 'b> {
813 evaluator: &'a mut ArithmeticEvaluator<'b>,
814 row_index: usize,
815 }
816
817 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
818 fn evaluate(
819 &mut self,
820 expr: &SqlExpression,
821 _row_index: usize,
822 ) -> Result<DataValue> {
823 self.evaluator.evaluate(expr, self.row_index)
824 }
825 }
826
827 let mut adapter = EvaluatorAdapter {
828 evaluator: self,
829 row_index,
830 };
831
832 return window_fn.compute(&context, row_index, args, &mut adapter);
834 }
835
836 let context = self.get_or_create_window_context(spec)?;
838
839 match name_upper.as_str() {
840 "LAG" => {
841 if args.is_empty() {
843 return Err(anyhow!("LAG requires at least 1 argument"));
844 }
845
846 let column = match &args[0] {
848 SqlExpression::Column(col) => col.clone(),
849 _ => return Err(anyhow!("LAG first argument must be a column")),
850 };
851
852 let offset = if args.len() > 1 {
854 match self.evaluate(&args[1], row_index)? {
855 DataValue::Integer(i) => i as i32,
856 _ => return Err(anyhow!("LAG offset must be an integer")),
857 }
858 } else {
859 1
860 };
861
862 Ok(context
864 .get_offset_value(row_index, -offset, &column.name)
865 .unwrap_or(DataValue::Null))
866 }
867 "LEAD" => {
868 if args.is_empty() {
870 return Err(anyhow!("LEAD requires at least 1 argument"));
871 }
872
873 let column = match &args[0] {
875 SqlExpression::Column(col) => col.clone(),
876 _ => return Err(anyhow!("LEAD first argument must be a column")),
877 };
878
879 let offset = if args.len() > 1 {
881 match self.evaluate(&args[1], row_index)? {
882 DataValue::Integer(i) => i as i32,
883 _ => return Err(anyhow!("LEAD offset must be an integer")),
884 }
885 } else {
886 1
887 };
888
889 Ok(context
891 .get_offset_value(row_index, offset, &column.name)
892 .unwrap_or(DataValue::Null))
893 }
894 "ROW_NUMBER" => {
895 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
897 }
898 "FIRST_VALUE" => {
899 if args.is_empty() {
901 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
902 }
903
904 let column = match &args[0] {
905 SqlExpression::Column(col) => col.clone(),
906 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
907 };
908
909 if context.has_frame() {
911 Ok(context
912 .get_frame_first_value(row_index, &column.name)
913 .unwrap_or(DataValue::Null))
914 } else {
915 Ok(context
916 .get_first_value(row_index, &column.name)
917 .unwrap_or(DataValue::Null))
918 }
919 }
920 "LAST_VALUE" => {
921 if args.is_empty() {
923 return Err(anyhow!("LAST_VALUE requires 1 argument"));
924 }
925
926 let column = match &args[0] {
927 SqlExpression::Column(col) => col.clone(),
928 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
929 };
930
931 if context.has_frame() {
933 Ok(context
934 .get_frame_last_value(row_index, &column.name)
935 .unwrap_or(DataValue::Null))
936 } else {
937 Ok(context
938 .get_last_value(row_index, &column.name)
939 .unwrap_or(DataValue::Null))
940 }
941 }
942 "SUM" => {
943 if args.is_empty() {
945 return Err(anyhow!("SUM requires 1 argument"));
946 }
947
948 let column = match &args[0] {
949 SqlExpression::Column(col) => col.clone(),
950 _ => return Err(anyhow!("SUM argument must be a column")),
951 };
952
953 if context.has_frame() {
955 Ok(context
956 .get_frame_sum(row_index, &column.name)
957 .unwrap_or(DataValue::Null))
958 } else {
959 Ok(context
960 .get_partition_sum(row_index, &column.name)
961 .unwrap_or(DataValue::Null))
962 }
963 }
964 "AVG" => {
965 if args.is_empty() {
967 return Err(anyhow!("AVG requires 1 argument"));
968 }
969
970 let column = match &args[0] {
971 SqlExpression::Column(col) => col.clone(),
972 _ => return Err(anyhow!("AVG argument must be a column")),
973 };
974
975 Ok(context
976 .get_frame_avg(row_index, &column.name)
977 .unwrap_or(DataValue::Null))
978 }
979 "STDDEV" | "STDEV" => {
980 if args.is_empty() {
982 return Err(anyhow!("STDDEV requires 1 argument"));
983 }
984
985 let column = match &args[0] {
986 SqlExpression::Column(col) => col.clone(),
987 _ => return Err(anyhow!("STDDEV argument must be a column")),
988 };
989
990 Ok(context
991 .get_frame_stddev(row_index, &column.name)
992 .unwrap_or(DataValue::Null))
993 }
994 "VARIANCE" | "VAR" => {
995 if args.is_empty() {
997 return Err(anyhow!("VARIANCE requires 1 argument"));
998 }
999
1000 let column = match &args[0] {
1001 SqlExpression::Column(col) => col.clone(),
1002 _ => return Err(anyhow!("VARIANCE argument must be a column")),
1003 };
1004
1005 Ok(context
1006 .get_frame_variance(row_index, &column.name)
1007 .unwrap_or(DataValue::Null))
1008 }
1009 "MIN" => {
1010 if args.is_empty() {
1012 return Err(anyhow!("MIN requires 1 argument"));
1013 }
1014
1015 let column = match &args[0] {
1016 SqlExpression::Column(col) => col.clone(),
1017 _ => return Err(anyhow!("MIN argument must be a column")),
1018 };
1019
1020 let frame_rows = context.get_frame_rows(row_index);
1021 if frame_rows.is_empty() {
1022 return Ok(DataValue::Null);
1023 }
1024
1025 let source_table = context.source();
1026 let col_idx = source_table
1027 .get_column_index(&column.name)
1028 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1029
1030 let mut min_value: Option<DataValue> = None;
1031 for &row_idx in &frame_rows {
1032 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1033 if !matches!(value, DataValue::Null) {
1034 match &min_value {
1035 None => min_value = Some(value.clone()),
1036 Some(current_min) => {
1037 if value < current_min {
1038 min_value = Some(value.clone());
1039 }
1040 }
1041 }
1042 }
1043 }
1044 }
1045
1046 Ok(min_value.unwrap_or(DataValue::Null))
1047 }
1048 "MAX" => {
1049 if args.is_empty() {
1051 return Err(anyhow!("MAX requires 1 argument"));
1052 }
1053
1054 let column = match &args[0] {
1055 SqlExpression::Column(col) => col.clone(),
1056 _ => return Err(anyhow!("MAX argument must be a column")),
1057 };
1058
1059 let frame_rows = context.get_frame_rows(row_index);
1060 if frame_rows.is_empty() {
1061 return Ok(DataValue::Null);
1062 }
1063
1064 let source_table = context.source();
1065 let col_idx = source_table
1066 .get_column_index(&column.name)
1067 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1068
1069 let mut max_value: Option<DataValue> = None;
1070 for &row_idx in &frame_rows {
1071 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1072 if !matches!(value, DataValue::Null) {
1073 match &max_value {
1074 None => max_value = Some(value.clone()),
1075 Some(current_max) => {
1076 if value > current_max {
1077 max_value = Some(value.clone());
1078 }
1079 }
1080 }
1081 }
1082 }
1083 }
1084
1085 Ok(max_value.unwrap_or(DataValue::Null))
1086 }
1087 "COUNT" => {
1088 if args.is_empty() {
1092 if context.has_frame() {
1094 Ok(context
1095 .get_frame_count(row_index, None)
1096 .unwrap_or(DataValue::Null))
1097 } else {
1098 Ok(context
1099 .get_partition_count(row_index, None)
1100 .unwrap_or(DataValue::Null))
1101 }
1102 } else {
1103 let column = match &args[0] {
1105 SqlExpression::Column(col) => {
1106 if col.name == "*" {
1107 if context.has_frame() {
1109 return Ok(context
1110 .get_frame_count(row_index, None)
1111 .unwrap_or(DataValue::Null));
1112 } else {
1113 return Ok(context
1114 .get_partition_count(row_index, None)
1115 .unwrap_or(DataValue::Null));
1116 }
1117 }
1118 col.clone()
1119 }
1120 SqlExpression::StringLiteral(s) if s == "*" => {
1121 if context.has_frame() {
1123 return Ok(context
1124 .get_frame_count(row_index, None)
1125 .unwrap_or(DataValue::Null));
1126 } else {
1127 return Ok(context
1128 .get_partition_count(row_index, None)
1129 .unwrap_or(DataValue::Null));
1130 }
1131 }
1132 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1133 };
1134
1135 if context.has_frame() {
1137 Ok(context
1138 .get_frame_count(row_index, Some(&column.name))
1139 .unwrap_or(DataValue::Null))
1140 } else {
1141 Ok(context
1142 .get_partition_count(row_index, Some(&column.name))
1143 .unwrap_or(DataValue::Null))
1144 }
1145 }
1146 }
1147 _ => Err(anyhow!("Unknown window function: {}", name)),
1148 }
1149 }
1150
1151 fn evaluate_method_call(
1153 &mut self,
1154 object: &str,
1155 method: &str,
1156 args: &[SqlExpression],
1157 row_index: usize,
1158 ) -> Result<DataValue> {
1159 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1161 let suggestion = self.find_similar_column(object);
1162 match suggestion {
1163 Some(similar) => {
1164 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1165 }
1166 None => anyhow!("Column '{}' not found", object),
1167 }
1168 })?;
1169
1170 let cell_value = self.table.get_value(row_index, col_index).cloned();
1171
1172 self.evaluate_method_on_value(
1173 &cell_value.unwrap_or(DataValue::Null),
1174 method,
1175 args,
1176 row_index,
1177 )
1178 }
1179
1180 fn evaluate_method_on_value(
1182 &mut self,
1183 value: &DataValue,
1184 method: &str,
1185 args: &[SqlExpression],
1186 row_index: usize,
1187 ) -> Result<DataValue> {
1188 let function_name = match method.to_lowercase().as_str() {
1193 "trim" => "TRIM",
1194 "trimstart" | "trimbegin" => "TRIMSTART",
1195 "trimend" => "TRIMEND",
1196 "length" | "len" => "LENGTH",
1197 "contains" => "CONTAINS",
1198 "startswith" => "STARTSWITH",
1199 "endswith" => "ENDSWITH",
1200 "indexof" => "INDEXOF",
1201 _ => method, };
1203
1204 if self.function_registry.get(function_name).is_some() {
1206 debug!(
1207 "Proxying method '{}' through function registry as '{}'",
1208 method, function_name
1209 );
1210
1211 let mut func_args = vec![value.clone()];
1213
1214 for arg in args {
1216 func_args.push(self.evaluate(arg, row_index)?);
1217 }
1218
1219 let func = self.function_registry.get(function_name).unwrap();
1221 return func.evaluate(&func_args);
1222 }
1223
1224 Err(anyhow!(
1227 "Method '{}' not found. It should be registered in the function registry.",
1228 method
1229 ))
1230 }
1231
1232 fn evaluate_case_expression(
1234 &mut self,
1235 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1236 else_branch: &Option<Box<SqlExpression>>,
1237 row_index: usize,
1238 ) -> Result<DataValue> {
1239 debug!(
1240 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1241 row_index
1242 );
1243
1244 for branch in when_branches {
1246 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1248
1249 if condition_result {
1250 debug!("CASE: WHEN condition matched, evaluating result expression");
1251 return self.evaluate(&branch.result, row_index);
1252 }
1253 }
1254
1255 if let Some(else_expr) = else_branch {
1257 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1258 self.evaluate(else_expr, row_index)
1259 } else {
1260 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1261 Ok(DataValue::Null)
1262 }
1263 }
1264
1265 fn evaluate_simple_case_expression(
1267 &mut self,
1268 expr: &Box<SqlExpression>,
1269 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1270 else_branch: &Option<Box<SqlExpression>>,
1271 row_index: usize,
1272 ) -> Result<DataValue> {
1273 debug!(
1274 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1275 row_index
1276 );
1277
1278 let case_value = self.evaluate(expr, row_index)?;
1280 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1281
1282 for branch in when_branches {
1284 let when_value = self.evaluate(&branch.value, row_index)?;
1286
1287 if self.values_equal(&case_value, &when_value)? {
1289 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1290 return self.evaluate(&branch.result, row_index);
1291 }
1292 }
1293
1294 if let Some(else_expr) = else_branch {
1296 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1297 self.evaluate(else_expr, row_index)
1298 } else {
1299 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1300 Ok(DataValue::Null)
1301 }
1302 }
1303
1304 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1306 match (left, right) {
1307 (DataValue::Null, DataValue::Null) => Ok(true),
1308 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1309 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1310 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1311 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1312 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1313 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1314 (DataValue::Integer(a), DataValue::Float(b)) => {
1316 Ok((*a as f64 - b).abs() < f64::EPSILON)
1317 }
1318 (DataValue::Float(a), DataValue::Integer(b)) => {
1319 Ok((a - *b as f64).abs() < f64::EPSILON)
1320 }
1321 _ => Ok(false),
1322 }
1323 }
1324
1325 fn evaluate_condition_as_bool(
1327 &mut self,
1328 expr: &SqlExpression,
1329 row_index: usize,
1330 ) -> Result<bool> {
1331 let value = self.evaluate(expr, row_index)?;
1332
1333 match value {
1334 DataValue::Boolean(b) => Ok(b),
1335 DataValue::Integer(i) => Ok(i != 0),
1336 DataValue::Float(f) => Ok(f != 0.0),
1337 DataValue::Null => Ok(false),
1338 DataValue::String(s) => Ok(!s.is_empty()),
1339 DataValue::InternedString(s) => Ok(!s.is_empty()),
1340 _ => Ok(true), }
1342 }
1343
1344 fn evaluate_datetime_constructor(
1346 &self,
1347 year: i32,
1348 month: u32,
1349 day: u32,
1350 hour: Option<u32>,
1351 minute: Option<u32>,
1352 second: Option<u32>,
1353 ) -> Result<DataValue> {
1354 use chrono::{NaiveDate, TimeZone, Utc};
1355
1356 let date = NaiveDate::from_ymd_opt(year, month, day)
1358 .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1359
1360 let hour = hour.unwrap_or(0);
1362 let minute = minute.unwrap_or(0);
1363 let second = second.unwrap_or(0);
1364
1365 let naive_datetime = date
1366 .and_hms_opt(hour, minute, second)
1367 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1368
1369 let datetime = Utc.from_utc_datetime(&naive_datetime);
1371
1372 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1374 Ok(DataValue::String(datetime_str))
1375 }
1376
1377 fn evaluate_datetime_today(
1379 &self,
1380 hour: Option<u32>,
1381 minute: Option<u32>,
1382 second: Option<u32>,
1383 ) -> Result<DataValue> {
1384 use chrono::{TimeZone, Utc};
1385
1386 let today = Utc::now().date_naive();
1388
1389 let hour = hour.unwrap_or(0);
1391 let minute = minute.unwrap_or(0);
1392 let second = second.unwrap_or(0);
1393
1394 let naive_datetime = today
1395 .and_hms_opt(hour, minute, second)
1396 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1397
1398 let datetime = Utc.from_utc_datetime(&naive_datetime);
1400
1401 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1403 Ok(DataValue::String(datetime_str))
1404 }
1405}
1406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5` and `c = 4`.
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(name));
        }

        let row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(row).unwrap();

        table
    }

    /// Shorthand for an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    // A bare column reference yields the cell value of that row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let value = evaluator.evaluate(&column("a"), 0).unwrap();
        assert_eq!(value, DataValue::Integer(10));
    }

    // Numeric literals parse to integers or floats depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal, expected) in cases {
            let expr = SqlExpression::NumberLiteral(literal.to_string());
            let value = evaluator.evaluate(&expr, 0).unwrap();
            assert_eq!(value, expected);
        }
    }

    // Integer + integer stays integer; mixing in a float produces a float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let int_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        let mixed_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    // Integer * float produces a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    // Exact integer division stays integer; inexact division becomes a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let exact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        let inexact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    // Dividing by zero is reported as an error, not a panic or infinity.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    // `a * b` over the test row is 10 * 2.5.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let expr = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "*".to_string(),
            right: Box::new(column("b")),
        };

        let value = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(value, DataValue::Float(25.0));
    }
}