1use crate::config::global::get_date_notation;
2use crate::data::data_view::DataView;
3use crate::data::datatable::{DataTable, DataValue};
4use crate::data::value_comparisons::compare_with_op;
5use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
8use crate::sql::parser::ast::{ColumnRef, WindowSpec};
9use crate::sql::recursive_parser::SqlExpression;
10use crate::sql::window_context::WindowContext;
11use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
12use anyhow::{anyhow, Result};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::debug;
16
17pub struct ArithmeticEvaluator<'a> {
20 table: &'a DataTable,
21 _date_notation: String,
22 function_registry: Arc<FunctionRegistry>,
23 aggregate_registry: Arc<AggregateRegistry>, new_aggregate_registry: Arc<AggregateFunctionRegistry>, window_function_registry: Arc<WindowFunctionRegistry>,
26 visible_rows: Option<Vec<usize>>, window_contexts: HashMap<String, Arc<WindowContext>>, table_aliases: HashMap<String, String>, }
30
31impl<'a> ArithmeticEvaluator<'a> {
32 #[must_use]
33 pub fn new(table: &'a DataTable) -> Self {
34 Self {
35 table,
36 _date_notation: get_date_notation(),
37 function_registry: Arc::new(FunctionRegistry::new()),
38 aggregate_registry: Arc::new(AggregateRegistry::new()),
39 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
40 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
41 visible_rows: None,
42 window_contexts: HashMap::new(),
43 table_aliases: HashMap::new(),
44 }
45 }
46
47 #[must_use]
48 pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
49 Self {
50 table,
51 _date_notation: date_notation,
52 function_registry: Arc::new(FunctionRegistry::new()),
53 aggregate_registry: Arc::new(AggregateRegistry::new()),
54 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
55 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
56 visible_rows: None,
57 window_contexts: HashMap::new(),
58 table_aliases: HashMap::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
65 self.visible_rows = Some(rows);
66 self
67 }
68
69 #[must_use]
71 pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
72 self.table_aliases = aliases;
73 self
74 }
75
76 #[must_use]
77 pub fn with_date_notation_and_registry(
78 table: &'a DataTable,
79 date_notation: String,
80 function_registry: Arc<FunctionRegistry>,
81 ) -> Self {
82 Self {
83 table,
84 _date_notation: date_notation,
85 function_registry,
86 aggregate_registry: Arc::new(AggregateRegistry::new()),
87 new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
88 window_function_registry: Arc::new(WindowFunctionRegistry::new()),
89 visible_rows: None,
90 window_contexts: HashMap::new(),
91 table_aliases: HashMap::new(),
92 }
93 }
94
95 fn find_similar_column(&self, name: &str) -> Option<String> {
97 let columns = self.table.column_names();
98 let mut best_match: Option<(String, usize)> = None;
99
100 for col in columns {
101 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
102 let max_distance = if name.len() > 10 { 3 } else { 2 };
105 if distance <= max_distance {
106 match &best_match {
107 None => best_match = Some((col, distance)),
108 Some((_, best_dist)) if distance < *best_dist => {
109 best_match = Some((col, distance));
110 }
111 _ => {}
112 }
113 }
114 }
115
116 best_match.map(|(name, _)| name)
117 }
118
119 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
121 crate::sql::functions::string_methods::EditDistanceFunction::calculate_edit_distance(s1, s2)
123 }
124
125 pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
127 debug!(
128 "ArithmeticEvaluator: evaluating {:?} for row {}",
129 expr, row_index
130 );
131
132 match expr {
133 SqlExpression::Column(column_ref) => self.evaluate_column_ref(column_ref, row_index),
134 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
135 SqlExpression::BooleanLiteral(b) => Ok(DataValue::Boolean(*b)),
136 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
137 SqlExpression::Null => Ok(DataValue::Null),
138 SqlExpression::BinaryOp { left, op, right } => {
139 self.evaluate_binary_op(left, op, right, row_index)
140 }
141 SqlExpression::FunctionCall {
142 name,
143 args,
144 distinct,
145 } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
146 SqlExpression::WindowFunction {
147 name,
148 args,
149 window_spec,
150 } => self.evaluate_window_function(name, args, window_spec, row_index),
151 SqlExpression::MethodCall {
152 object,
153 method,
154 args,
155 } => self.evaluate_method_call(object, method, args, row_index),
156 SqlExpression::ChainedMethodCall { base, method, args } => {
157 let base_value = self.evaluate(base, row_index)?;
159 self.evaluate_method_on_value(&base_value, method, args, row_index)
160 }
161 SqlExpression::CaseExpression {
162 when_branches,
163 else_branch,
164 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
165 SqlExpression::SimpleCaseExpression {
166 expr,
167 when_branches,
168 else_branch,
169 } => self.evaluate_simple_case_expression(expr, when_branches, else_branch, row_index),
170 SqlExpression::DateTimeConstructor {
171 year,
172 month,
173 day,
174 hour,
175 minute,
176 second,
177 } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
178 SqlExpression::DateTimeToday {
179 hour,
180 minute,
181 second,
182 } => self.evaluate_datetime_today(*hour, *minute, *second),
183 _ => Err(anyhow!(
184 "Unsupported expression type for arithmetic evaluation: {:?}",
185 expr
186 )),
187 }
188 }
189
190 fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
192 if let Some(table_prefix) = &column_ref.table_prefix {
193 let actual_table = self
195 .table_aliases
196 .get(table_prefix)
197 .map(|s| s.as_str())
198 .unwrap_or(table_prefix);
199
200 let qualified_name = format!("{}.{}", actual_table, column_ref.name);
202
203 if let Some(col_idx) = self.table.find_column_by_qualified_name(&qualified_name) {
204 debug!(
205 "Resolved {}.{} -> '{}' at index {}",
206 table_prefix, column_ref.name, qualified_name, col_idx
207 );
208 return self
209 .table
210 .get_value(row_index, col_idx)
211 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
212 .map(|v| v.clone());
213 }
214
215 if let Some(col_idx) = self.table.get_column_index(&column_ref.name) {
217 debug!(
218 "Resolved {}.{} -> unqualified '{}' at index {}",
219 table_prefix, column_ref.name, column_ref.name, col_idx
220 );
221 return self
222 .table
223 .get_value(row_index, col_idx)
224 .ok_or_else(|| anyhow!("Row {} out of bounds", row_index))
225 .map(|v| v.clone());
226 }
227
228 Err(anyhow!(
230 "Column '{}' not found. Table '{}' may not support qualified column names",
231 qualified_name,
232 actual_table
233 ))
234 } else {
235 self.evaluate_column(&column_ref.name, row_index)
237 }
238 }
239
240 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
242 let resolved_column = if column_name.contains('.') {
244 if let Some(dot_pos) = column_name.rfind('.') {
246 let _table_or_alias = &column_name[..dot_pos];
247 let col_name = &column_name[dot_pos + 1..];
248
249 debug!(
252 "Resolving qualified column: {} -> {}",
253 column_name, col_name
254 );
255 col_name.to_string()
256 } else {
257 column_name.to_string()
258 }
259 } else {
260 column_name.to_string()
261 };
262
263 let col_index = if let Some(idx) = self.table.get_column_index(&resolved_column) {
264 idx
265 } else if resolved_column != column_name {
266 if let Some(idx) = self.table.get_column_index(column_name) {
268 idx
269 } else {
270 let suggestion = self.find_similar_column(&resolved_column);
271 return Err(match suggestion {
272 Some(similar) => anyhow!(
273 "Column '{}' not found. Did you mean '{}'?",
274 column_name,
275 similar
276 ),
277 None => anyhow!("Column '{}' not found", column_name),
278 });
279 }
280 } else {
281 let suggestion = self.find_similar_column(&resolved_column);
282 return Err(match suggestion {
283 Some(similar) => anyhow!(
284 "Column '{}' not found. Did you mean '{}'?",
285 column_name,
286 similar
287 ),
288 None => anyhow!("Column '{}' not found", column_name),
289 });
290 };
291
292 if row_index >= self.table.row_count() {
293 return Err(anyhow!("Row index {} out of bounds", row_index));
294 }
295
296 let row = self
297 .table
298 .get_row(row_index)
299 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
300
301 let value = row
302 .get(col_index)
303 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
304
305 Ok(value.clone())
306 }
307
308 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
310 if let Ok(int_val) = number_str.parse::<i64>() {
312 return Ok(DataValue::Integer(int_val));
313 }
314
315 if let Ok(float_val) = number_str.parse::<f64>() {
317 return Ok(DataValue::Float(float_val));
318 }
319
320 Err(anyhow!("Invalid number literal: {}", number_str))
321 }
322
323 fn evaluate_binary_op(
325 &mut self,
326 left: &SqlExpression,
327 op: &str,
328 right: &SqlExpression,
329 row_index: usize,
330 ) -> Result<DataValue> {
331 let left_val = self.evaluate(left, row_index)?;
332 let right_val = self.evaluate(right, row_index)?;
333
334 debug!(
335 "ArithmeticEvaluator: {} {} {}",
336 self.format_value(&left_val),
337 op,
338 self.format_value(&right_val)
339 );
340
341 match op {
342 "+" => self.add_values(&left_val, &right_val),
343 "-" => self.subtract_values(&left_val, &right_val),
344 "*" => self.multiply_values(&left_val, &right_val),
345 "/" => self.divide_values(&left_val, &right_val),
346 "%" => {
347 let args = vec![left.clone(), right.clone()];
349 self.evaluate_function("MOD", &args, row_index)
350 }
351 ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
354 let result = compare_with_op(&left_val, &right_val, op, false);
355 Ok(DataValue::Boolean(result))
356 }
357 "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
359 "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
360 "AND" => {
362 let left_bool = self.to_bool(&left_val)?;
363 let right_bool = self.to_bool(&right_val)?;
364 Ok(DataValue::Boolean(left_bool && right_bool))
365 }
366 "OR" => {
367 let left_bool = self.to_bool(&left_val)?;
368 let right_bool = self.to_bool(&right_val)?;
369 Ok(DataValue::Boolean(left_bool || right_bool))
370 }
371 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
372 }
373 }
374
375 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
377 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
379 return Ok(DataValue::Null);
380 }
381
382 match (left, right) {
383 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
384 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
385 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
386 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
387 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
388 }
389 }
390
391 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
393 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
395 return Ok(DataValue::Null);
396 }
397
398 match (left, right) {
399 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
400 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
401 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
402 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
403 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
404 }
405 }
406
407 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
409 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
411 return Ok(DataValue::Null);
412 }
413
414 match (left, right) {
415 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
416 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
417 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
418 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
419 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
420 }
421 }
422
423 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
425 if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
427 return Ok(DataValue::Null);
428 }
429
430 let is_zero = match right {
432 DataValue::Integer(0) => true,
433 DataValue::Float(f) if *f == 0.0 => true, _ => false,
435 };
436
437 if is_zero {
438 return Err(anyhow!("Division by zero"));
439 }
440
441 match (left, right) {
442 (DataValue::Integer(a), DataValue::Integer(b)) => {
443 if a % b == 0 {
445 Ok(DataValue::Integer(a / b))
446 } else {
447 Ok(DataValue::Float(*a as f64 / *b as f64))
448 }
449 }
450 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
451 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
452 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
453 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
454 }
455 }
456
457 fn format_value(&self, value: &DataValue) -> String {
459 match value {
460 DataValue::Integer(i) => i.to_string(),
461 DataValue::Float(f) => f.to_string(),
462 DataValue::String(s) => format!("'{s}'"),
463 _ => format!("{value:?}"),
464 }
465 }
466
467 fn to_bool(&self, value: &DataValue) -> Result<bool> {
469 match value {
470 DataValue::Boolean(b) => Ok(*b),
471 DataValue::Integer(i) => Ok(*i != 0),
472 DataValue::Float(f) => Ok(*f != 0.0),
473 DataValue::Null => Ok(false),
474 _ => Err(anyhow!("Cannot convert {:?} to boolean", value)),
475 }
476 }
477
478 fn evaluate_function_with_distinct(
480 &mut self,
481 name: &str,
482 args: &[SqlExpression],
483 distinct: bool,
484 row_index: usize,
485 ) -> Result<DataValue> {
486 if distinct {
488 let name_upper = name.to_uppercase();
489
490 if self.aggregate_registry.is_aggregate(&name_upper)
492 || self.new_aggregate_registry.contains(&name_upper)
493 {
494 return self.evaluate_aggregate_with_distinct(&name_upper, args, row_index);
495 } else {
496 return Err(anyhow!(
497 "DISTINCT can only be used with aggregate functions"
498 ));
499 }
500 }
501
502 self.evaluate_function(name, args, row_index)
504 }
505
506 fn evaluate_aggregate_with_distinct(
507 &mut self,
508 name: &str,
509 args: &[SqlExpression],
510 _row_index: usize,
511 ) -> Result<DataValue> {
512 let name_upper = name.to_uppercase();
513
514 if self.new_aggregate_registry.get(&name_upper).is_some() {
516 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
517 visible.clone()
518 } else {
519 (0..self.table.rows.len()).collect()
520 };
521
522 let mut vals = Vec::new();
524 for &row_idx in &rows_to_process {
525 if !args.is_empty() {
526 let value = self.evaluate(&args[0], row_idx)?;
527 vals.push(value);
528 }
529 }
530
531 let mut seen = HashSet::new();
533 let unique_values: Vec<_> = vals
534 .into_iter()
535 .filter(|v| {
536 let key = format!("{:?}", v);
537 seen.insert(key)
538 })
539 .collect();
540
541 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
543 let mut state = agg_func.create_state();
544
545 for value in &unique_values {
547 state.accumulate(value)?;
548 }
549
550 return Ok(state.finalize());
551 }
552
553 if self.aggregate_registry.get(&name_upper).is_some() {
555 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
557 visible.clone()
558 } else {
559 (0..self.table.rows.len()).collect()
560 };
561
562 if name_upper == "STRING_AGG" && args.len() >= 2 {
564 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
566 if args.len() >= 2 {
568 let separator = self.evaluate(&args[1], 0)?; match separator {
570 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
571 DataValue::InternedString(s) => {
572 crate::sql::aggregates::StringAggState::new(&s)
573 }
574 _ => crate::sql::aggregates::StringAggState::new(","), }
576 } else {
577 crate::sql::aggregates::StringAggState::new(",")
578 },
579 );
580
581 let mut seen_values = HashSet::new();
584
585 for &row_idx in &rows_to_process {
586 let value = self.evaluate(&args[0], row_idx)?;
587
588 if !seen_values.insert(value.clone()) {
590 continue; }
592
593 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
595 agg_func.accumulate(&mut state, &value)?;
596 }
597
598 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
600 return Ok(agg_func.finalize(state));
601 }
602
603 let mut vals = Vec::new();
606 for &row_idx in &rows_to_process {
607 if !args.is_empty() {
608 let value = self.evaluate(&args[0], row_idx)?;
609 vals.push(value);
610 }
611 }
612
613 let mut seen = HashSet::new();
615 let mut unique_values = Vec::new();
616 for value in vals {
617 if seen.insert(value.clone()) {
618 unique_values.push(value);
619 }
620 }
621
622 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
624 let mut state = agg_func.init();
625
626 for value in &unique_values {
628 agg_func.accumulate(&mut state, value)?;
629 }
630
631 return Ok(agg_func.finalize(state));
632 }
633
634 Err(anyhow!("Unknown aggregate function: {}", name))
635 }
636
637 fn evaluate_function(
638 &mut self,
639 name: &str,
640 args: &[SqlExpression],
641 row_index: usize,
642 ) -> Result<DataValue> {
643 let name_upper = name.to_uppercase();
645
646 if self.new_aggregate_registry.get(&name_upper).is_some() {
648 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
650 visible.clone()
651 } else {
652 (0..self.table.rows.len()).collect()
653 };
654
655 let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
657 let mut state = agg_func.create_state();
658
659 if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
661 if args.is_empty()
662 || (args.len() == 1
663 && matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
664 || (args.len() == 1
665 && matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
666 {
667 for _ in &rows_to_process {
669 state.accumulate(&DataValue::Integer(1))?;
670 }
671 } else {
672 for &row_idx in &rows_to_process {
674 let value = self.evaluate(&args[0], row_idx)?;
675 state.accumulate(&value)?;
676 }
677 }
678 } else {
679 if !args.is_empty() {
681 for &row_idx in &rows_to_process {
682 let value = self.evaluate(&args[0], row_idx)?;
683 state.accumulate(&value)?;
684 }
685 }
686 }
687
688 return Ok(state.finalize());
689 }
690
691 if self.aggregate_registry.get(&name_upper).is_some() {
693 let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
695 visible.clone()
696 } else {
697 (0..self.table.rows.len()).collect()
698 };
699
700 if name_upper == "STRING_AGG" && args.len() >= 2 {
702 let mut state = crate::sql::aggregates::AggregateState::StringAgg(
704 if args.len() >= 2 {
706 let separator = self.evaluate(&args[1], 0)?; match separator {
708 DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
709 DataValue::InternedString(s) => {
710 crate::sql::aggregates::StringAggState::new(&s)
711 }
712 _ => crate::sql::aggregates::StringAggState::new(","), }
714 } else {
715 crate::sql::aggregates::StringAggState::new(",")
716 },
717 );
718
719 for &row_idx in &rows_to_process {
721 let value = self.evaluate(&args[0], row_idx)?;
722 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
724 agg_func.accumulate(&mut state, &value)?;
725 }
726
727 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
729 return Ok(agg_func.finalize(state));
730 }
731
732 let values = if !args.is_empty()
734 && !(args.len() == 1
735 && matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
736 {
737 let mut vals = Vec::new();
739 for &row_idx in &rows_to_process {
740 let value = self.evaluate(&args[0], row_idx)?;
741 vals.push(value);
742 }
743 Some(vals)
744 } else {
745 None
746 };
747
748 let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
750 let mut state = agg_func.init();
751
752 if let Some(values) = values {
753 for value in &values {
755 agg_func.accumulate(&mut state, value)?;
756 }
757 } else {
758 for _ in &rows_to_process {
760 agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
761 }
762 }
763
764 return Ok(agg_func.finalize(state));
765 }
766
767 if self.function_registry.get(name).is_some() {
769 let mut evaluated_args = Vec::new();
771 for arg in args {
772 evaluated_args.push(self.evaluate(arg, row_index)?);
773 }
774
775 let func = self.function_registry.get(name).unwrap();
777 return func.evaluate(&evaluated_args);
778 }
779
780 Err(anyhow!("Unknown function: {}", name))
782 }
783
784 fn get_or_create_window_context(&mut self, spec: &WindowSpec) -> Result<Arc<WindowContext>> {
786 let key = format!("{:?}", spec);
788
789 if let Some(context) = self.window_contexts.get(&key) {
790 return Ok(Arc::clone(context));
791 }
792
793 let data_view = if let Some(ref _visible_rows) = self.visible_rows {
795 let view = DataView::new(Arc::new(self.table.clone()));
797 view
800 } else {
801 DataView::new(Arc::new(self.table.clone()))
802 };
803
804 let context = WindowContext::new_with_spec(Arc::new(data_view), spec.clone())?;
806
807 let context = Arc::new(context);
808 self.window_contexts.insert(key, Arc::clone(&context));
809 Ok(context)
810 }
811
812 fn evaluate_window_function(
814 &mut self,
815 name: &str,
816 args: &[SqlExpression],
817 spec: &WindowSpec,
818 row_index: usize,
819 ) -> Result<DataValue> {
820 let name_upper = name.to_uppercase();
821
822 debug!("Looking for window function {} in registry", name_upper);
824 if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
825 debug!("Found window function {} in registry", name_upper);
826
827 let window_fn = window_fn_arc.as_ref();
829
830 window_fn.validate_args(args)?;
832
833 let transformed_spec = window_fn.transform_window_spec(spec, args)?;
835
836 let context = self.get_or_create_window_context(&transformed_spec)?;
838
839 struct EvaluatorAdapter<'a, 'b> {
841 evaluator: &'a mut ArithmeticEvaluator<'b>,
842 row_index: usize,
843 }
844
845 impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
846 fn evaluate(
847 &mut self,
848 expr: &SqlExpression,
849 _row_index: usize,
850 ) -> Result<DataValue> {
851 self.evaluator.evaluate(expr, self.row_index)
852 }
853 }
854
855 let mut adapter = EvaluatorAdapter {
856 evaluator: self,
857 row_index,
858 };
859
860 return window_fn.compute(&context, row_index, args, &mut adapter);
862 }
863
864 let context = self.get_or_create_window_context(spec)?;
866
867 match name_upper.as_str() {
868 "LAG" => {
869 if args.is_empty() {
871 return Err(anyhow!("LAG requires at least 1 argument"));
872 }
873
874 let column = match &args[0] {
876 SqlExpression::Column(col) => col.clone(),
877 _ => return Err(anyhow!("LAG first argument must be a column")),
878 };
879
880 let offset = if args.len() > 1 {
882 match self.evaluate(&args[1], row_index)? {
883 DataValue::Integer(i) => i as i32,
884 _ => return Err(anyhow!("LAG offset must be an integer")),
885 }
886 } else {
887 1
888 };
889
890 Ok(context
892 .get_offset_value(row_index, -offset, &column.name)
893 .unwrap_or(DataValue::Null))
894 }
895 "LEAD" => {
896 if args.is_empty() {
898 return Err(anyhow!("LEAD requires at least 1 argument"));
899 }
900
901 let column = match &args[0] {
903 SqlExpression::Column(col) => col.clone(),
904 _ => return Err(anyhow!("LEAD first argument must be a column")),
905 };
906
907 let offset = if args.len() > 1 {
909 match self.evaluate(&args[1], row_index)? {
910 DataValue::Integer(i) => i as i32,
911 _ => return Err(anyhow!("LEAD offset must be an integer")),
912 }
913 } else {
914 1
915 };
916
917 Ok(context
919 .get_offset_value(row_index, offset, &column.name)
920 .unwrap_or(DataValue::Null))
921 }
922 "ROW_NUMBER" => {
923 Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
925 }
926 "FIRST_VALUE" => {
927 if args.is_empty() {
929 return Err(anyhow!("FIRST_VALUE requires 1 argument"));
930 }
931
932 let column = match &args[0] {
933 SqlExpression::Column(col) => col.clone(),
934 _ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
935 };
936
937 if context.has_frame() {
939 Ok(context
940 .get_frame_first_value(row_index, &column.name)
941 .unwrap_or(DataValue::Null))
942 } else {
943 Ok(context
944 .get_first_value(row_index, &column.name)
945 .unwrap_or(DataValue::Null))
946 }
947 }
948 "LAST_VALUE" => {
949 if args.is_empty() {
951 return Err(anyhow!("LAST_VALUE requires 1 argument"));
952 }
953
954 let column = match &args[0] {
955 SqlExpression::Column(col) => col.clone(),
956 _ => return Err(anyhow!("LAST_VALUE argument must be a column")),
957 };
958
959 if context.has_frame() {
961 Ok(context
962 .get_frame_last_value(row_index, &column.name)
963 .unwrap_or(DataValue::Null))
964 } else {
965 Ok(context
966 .get_last_value(row_index, &column.name)
967 .unwrap_or(DataValue::Null))
968 }
969 }
970 "SUM" => {
971 if args.is_empty() {
973 return Err(anyhow!("SUM requires 1 argument"));
974 }
975
976 let column = match &args[0] {
977 SqlExpression::Column(col) => col.clone(),
978 _ => return Err(anyhow!("SUM argument must be a column")),
979 };
980
981 if context.has_frame() {
983 Ok(context
984 .get_frame_sum(row_index, &column.name)
985 .unwrap_or(DataValue::Null))
986 } else {
987 Ok(context
988 .get_partition_sum(row_index, &column.name)
989 .unwrap_or(DataValue::Null))
990 }
991 }
992 "AVG" => {
993 if args.is_empty() {
995 return Err(anyhow!("AVG requires 1 argument"));
996 }
997
998 let column = match &args[0] {
999 SqlExpression::Column(col) => col.clone(),
1000 _ => return Err(anyhow!("AVG argument must be a column")),
1001 };
1002
1003 Ok(context
1004 .get_frame_avg(row_index, &column.name)
1005 .unwrap_or(DataValue::Null))
1006 }
1007 "STDDEV" | "STDEV" => {
1008 if args.is_empty() {
1010 return Err(anyhow!("STDDEV requires 1 argument"));
1011 }
1012
1013 let column = match &args[0] {
1014 SqlExpression::Column(col) => col.clone(),
1015 _ => return Err(anyhow!("STDDEV argument must be a column")),
1016 };
1017
1018 Ok(context
1019 .get_frame_stddev(row_index, &column.name)
1020 .unwrap_or(DataValue::Null))
1021 }
1022 "VARIANCE" | "VAR" => {
1023 if args.is_empty() {
1025 return Err(anyhow!("VARIANCE requires 1 argument"));
1026 }
1027
1028 let column = match &args[0] {
1029 SqlExpression::Column(col) => col.clone(),
1030 _ => return Err(anyhow!("VARIANCE argument must be a column")),
1031 };
1032
1033 Ok(context
1034 .get_frame_variance(row_index, &column.name)
1035 .unwrap_or(DataValue::Null))
1036 }
1037 "MIN" => {
1038 if args.is_empty() {
1040 return Err(anyhow!("MIN requires 1 argument"));
1041 }
1042
1043 let column = match &args[0] {
1044 SqlExpression::Column(col) => col.clone(),
1045 _ => return Err(anyhow!("MIN argument must be a column")),
1046 };
1047
1048 let frame_rows = context.get_frame_rows(row_index);
1049 if frame_rows.is_empty() {
1050 return Ok(DataValue::Null);
1051 }
1052
1053 let source_table = context.source();
1054 let col_idx = source_table
1055 .get_column_index(&column.name)
1056 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1057
1058 let mut min_value: Option<DataValue> = None;
1059 for &row_idx in &frame_rows {
1060 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1061 if !matches!(value, DataValue::Null) {
1062 match &min_value {
1063 None => min_value = Some(value.clone()),
1064 Some(current_min) => {
1065 if value < current_min {
1066 min_value = Some(value.clone());
1067 }
1068 }
1069 }
1070 }
1071 }
1072 }
1073
1074 Ok(min_value.unwrap_or(DataValue::Null))
1075 }
1076 "MAX" => {
1077 if args.is_empty() {
1079 return Err(anyhow!("MAX requires 1 argument"));
1080 }
1081
1082 let column = match &args[0] {
1083 SqlExpression::Column(col) => col.clone(),
1084 _ => return Err(anyhow!("MAX argument must be a column")),
1085 };
1086
1087 let frame_rows = context.get_frame_rows(row_index);
1088 if frame_rows.is_empty() {
1089 return Ok(DataValue::Null);
1090 }
1091
1092 let source_table = context.source();
1093 let col_idx = source_table
1094 .get_column_index(&column.name)
1095 .ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
1096
1097 let mut max_value: Option<DataValue> = None;
1098 for &row_idx in &frame_rows {
1099 if let Some(value) = source_table.get_value(row_idx, col_idx) {
1100 if !matches!(value, DataValue::Null) {
1101 match &max_value {
1102 None => max_value = Some(value.clone()),
1103 Some(current_max) => {
1104 if value > current_max {
1105 max_value = Some(value.clone());
1106 }
1107 }
1108 }
1109 }
1110 }
1111 }
1112
1113 Ok(max_value.unwrap_or(DataValue::Null))
1114 }
1115 "COUNT" => {
1116 if args.is_empty() {
1120 if context.has_frame() {
1122 Ok(context
1123 .get_frame_count(row_index, None)
1124 .unwrap_or(DataValue::Null))
1125 } else {
1126 Ok(context
1127 .get_partition_count(row_index, None)
1128 .unwrap_or(DataValue::Null))
1129 }
1130 } else {
1131 let column = match &args[0] {
1133 SqlExpression::Column(col) => {
1134 if col.name == "*" {
1135 if context.has_frame() {
1137 return Ok(context
1138 .get_frame_count(row_index, None)
1139 .unwrap_or(DataValue::Null));
1140 } else {
1141 return Ok(context
1142 .get_partition_count(row_index, None)
1143 .unwrap_or(DataValue::Null));
1144 }
1145 }
1146 col.clone()
1147 }
1148 SqlExpression::StringLiteral(s) if s == "*" => {
1149 if context.has_frame() {
1151 return Ok(context
1152 .get_frame_count(row_index, None)
1153 .unwrap_or(DataValue::Null));
1154 } else {
1155 return Ok(context
1156 .get_partition_count(row_index, None)
1157 .unwrap_or(DataValue::Null));
1158 }
1159 }
1160 _ => return Err(anyhow!("COUNT argument must be a column or *")),
1161 };
1162
1163 if context.has_frame() {
1165 Ok(context
1166 .get_frame_count(row_index, Some(&column.name))
1167 .unwrap_or(DataValue::Null))
1168 } else {
1169 Ok(context
1170 .get_partition_count(row_index, Some(&column.name))
1171 .unwrap_or(DataValue::Null))
1172 }
1173 }
1174 }
1175 _ => Err(anyhow!("Unknown window function: {}", name)),
1176 }
1177 }
1178
1179 fn evaluate_method_call(
1181 &mut self,
1182 object: &str,
1183 method: &str,
1184 args: &[SqlExpression],
1185 row_index: usize,
1186 ) -> Result<DataValue> {
1187 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1189 let suggestion = self.find_similar_column(object);
1190 match suggestion {
1191 Some(similar) => {
1192 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1193 }
1194 None => anyhow!("Column '{}' not found", object),
1195 }
1196 })?;
1197
1198 let cell_value = self.table.get_value(row_index, col_index).cloned();
1199
1200 self.evaluate_method_on_value(
1201 &cell_value.unwrap_or(DataValue::Null),
1202 method,
1203 args,
1204 row_index,
1205 )
1206 }
1207
1208 fn evaluate_method_on_value(
1210 &mut self,
1211 value: &DataValue,
1212 method: &str,
1213 args: &[SqlExpression],
1214 row_index: usize,
1215 ) -> Result<DataValue> {
1216 let function_name = match method.to_lowercase().as_str() {
1221 "trim" => "TRIM",
1222 "trimstart" | "trimbegin" => "TRIMSTART",
1223 "trimend" => "TRIMEND",
1224 "length" | "len" => "LENGTH",
1225 "contains" => "CONTAINS",
1226 "startswith" => "STARTSWITH",
1227 "endswith" => "ENDSWITH",
1228 "indexof" => "INDEXOF",
1229 _ => method, };
1231
1232 if self.function_registry.get(function_name).is_some() {
1234 debug!(
1235 "Proxying method '{}' through function registry as '{}'",
1236 method, function_name
1237 );
1238
1239 let mut func_args = vec![value.clone()];
1241
1242 for arg in args {
1244 func_args.push(self.evaluate(arg, row_index)?);
1245 }
1246
1247 let func = self.function_registry.get(function_name).unwrap();
1249 return func.evaluate(&func_args);
1250 }
1251
1252 Err(anyhow!(
1255 "Method '{}' not found. It should be registered in the function registry.",
1256 method
1257 ))
1258 }
1259
1260 fn evaluate_case_expression(
1262 &mut self,
1263 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1264 else_branch: &Option<Box<SqlExpression>>,
1265 row_index: usize,
1266 ) -> Result<DataValue> {
1267 debug!(
1268 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1269 row_index
1270 );
1271
1272 for branch in when_branches {
1274 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1276
1277 if condition_result {
1278 debug!("CASE: WHEN condition matched, evaluating result expression");
1279 return self.evaluate(&branch.result, row_index);
1280 }
1281 }
1282
1283 if let Some(else_expr) = else_branch {
1285 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1286 self.evaluate(else_expr, row_index)
1287 } else {
1288 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1289 Ok(DataValue::Null)
1290 }
1291 }
1292
1293 fn evaluate_simple_case_expression(
1295 &mut self,
1296 expr: &Box<SqlExpression>,
1297 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
1298 else_branch: &Option<Box<SqlExpression>>,
1299 row_index: usize,
1300 ) -> Result<DataValue> {
1301 debug!(
1302 "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
1303 row_index
1304 );
1305
1306 let case_value = self.evaluate(expr, row_index)?;
1308 debug!("Simple CASE: evaluated expression to {:?}", case_value);
1309
1310 for branch in when_branches {
1312 let when_value = self.evaluate(&branch.value, row_index)?;
1314
1315 if self.values_equal(&case_value, &when_value)? {
1317 debug!("Simple CASE: WHEN value matched, evaluating result expression");
1318 return self.evaluate(&branch.result, row_index);
1319 }
1320 }
1321
1322 if let Some(else_expr) = else_branch {
1324 debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
1325 self.evaluate(else_expr, row_index)
1326 } else {
1327 debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
1328 Ok(DataValue::Null)
1329 }
1330 }
1331
1332 fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
1334 match (left, right) {
1335 (DataValue::Null, DataValue::Null) => Ok(true),
1336 (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
1337 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
1338 (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
1339 (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
1340 (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
1341 (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
1342 (DataValue::Integer(a), DataValue::Float(b)) => {
1344 Ok((*a as f64 - b).abs() < f64::EPSILON)
1345 }
1346 (DataValue::Float(a), DataValue::Integer(b)) => {
1347 Ok((a - *b as f64).abs() < f64::EPSILON)
1348 }
1349 _ => Ok(false),
1350 }
1351 }
1352
1353 fn evaluate_condition_as_bool(
1355 &mut self,
1356 expr: &SqlExpression,
1357 row_index: usize,
1358 ) -> Result<bool> {
1359 let value = self.evaluate(expr, row_index)?;
1360
1361 match value {
1362 DataValue::Boolean(b) => Ok(b),
1363 DataValue::Integer(i) => Ok(i != 0),
1364 DataValue::Float(f) => Ok(f != 0.0),
1365 DataValue::Null => Ok(false),
1366 DataValue::String(s) => Ok(!s.is_empty()),
1367 DataValue::InternedString(s) => Ok(!s.is_empty()),
1368 _ => Ok(true), }
1370 }
1371
1372 fn evaluate_datetime_constructor(
1374 &self,
1375 year: i32,
1376 month: u32,
1377 day: u32,
1378 hour: Option<u32>,
1379 minute: Option<u32>,
1380 second: Option<u32>,
1381 ) -> Result<DataValue> {
1382 use chrono::{NaiveDate, TimeZone, Utc};
1383
1384 let date = NaiveDate::from_ymd_opt(year, month, day)
1386 .ok_or_else(|| anyhow!("Invalid date: {}-{}-{}", year, month, day))?;
1387
1388 let hour = hour.unwrap_or(0);
1390 let minute = minute.unwrap_or(0);
1391 let second = second.unwrap_or(0);
1392
1393 let naive_datetime = date
1394 .and_hms_opt(hour, minute, second)
1395 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1396
1397 let datetime = Utc.from_utc_datetime(&naive_datetime);
1399
1400 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1402 Ok(DataValue::String(datetime_str))
1403 }
1404
1405 fn evaluate_datetime_today(
1407 &self,
1408 hour: Option<u32>,
1409 minute: Option<u32>,
1410 second: Option<u32>,
1411 ) -> Result<DataValue> {
1412 use chrono::{TimeZone, Utc};
1413
1414 let today = Utc::now().date_naive();
1416
1417 let hour = hour.unwrap_or(0);
1419 let minute = minute.unwrap_or(0);
1420 let second = second.unwrap_or(0);
1421
1422 let naive_datetime = today
1423 .and_hms_opt(hour, minute, second)
1424 .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", hour, minute, second))?;
1425
1426 let datetime = Utc.from_utc_datetime(&naive_datetime);
1428
1429 let datetime_str = datetime.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
1431 Ok(DataValue::String(datetime_str))
1432 }
1433}
1434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table with columns `a = 10`, `b = 2.5`, `c = 4`.
    fn create_test_table() -> DataTable {
        let mut fixture = DataTable::new("test");
        for column_name in ["a", "b", "c"].iter() {
            fixture.add_column(DataColumn::new(*column_name));
        }

        let only_row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        fixture.add_row(only_row).unwrap();

        fixture
    }

    /// A bare column reference yields the cell value of that column.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    /// Number literals become integers or floats depending on their text.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let integer_literal = SqlExpression::NumberLiteral("42".to_string());
        assert_eq!(
            evaluator.evaluate(&integer_literal, 0).unwrap(),
            DataValue::Integer(42)
        );

        let float_literal = SqlExpression::NumberLiteral("3.14".to_string());
        assert_eq!(
            evaluator.evaluate(&float_literal, 0).unwrap(),
            DataValue::Float(3.14)
        );
    }

    /// Addition keeps integers integral and promotes mixed operands to float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let int_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        let mixed_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Multiplying an integer by a float produces a float.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Exact integer division stays integral; inexact division yields a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let exact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        let inexact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// A binary `*` over two columns multiplies their cell values (10 * 2.5).
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);

        let lhs = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let rhs = SqlExpression::Column(ColumnRef::unquoted("b".to_string()));
        let product_expr = SqlExpression::BinaryOp {
            left: Box::new(lhs),
            op: "*".to_string(),
            right: Box::new(rhs),
        };

        assert_eq!(
            evaluator.evaluate(&product_expr, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}