1mod aggregate;
35pub mod batch;
36mod executor;
37pub mod filter;
38mod scan;
39mod string_ops;
40
41pub mod simd_ops;
44
45mod simd_aggregate;
46pub mod simd_filter;
47mod simd_join;
48
49pub use aggregate::{
50 columnar_group_by, columnar_group_by_batch, compute_aggregates_from_batch,
51 compute_multiple_aggregates, evaluate_expression_to_column,
52 evaluate_expression_with_cached_column, extract_aggregates, AggregateOp, AggregateSource,
53 AggregateSpec,
54};
55pub use batch::{ColumnArray, ColumnarBatch};
56pub use executor::execute_columnar_batch;
57pub use filter::{
58 apply_columnar_filter, apply_columnar_filter_simd_streaming, create_filter_bitmap,
59 create_filter_bitmap_tree, evaluate_predicate_tree, extract_column_predicates,
60 extract_predicate_tree, ColumnPredicate, PredicateTree,
61};
62use log;
63pub use scan::ColumnarScan;
64pub use simd_aggregate::{can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64};
65pub use simd_filter::{
66 simd_create_filter_mask, simd_create_filter_mask_auto, simd_create_filter_mask_packed,
67 simd_filter_batch, simd_filter_to_indices,
68};
69#[cfg(feature = "parallel")]
70pub use simd_filter::{simd_create_filter_mask_parallel, simd_filter_batch_parallel};
71pub use simd_join::columnar_hash_join_inner;
72pub use simd_ops::PackedMask;
73use vibesql_storage::Row;
74use vibesql_types::SqlValue;
75
76use crate::{errors::ExecutorError, schema::CombinedSchema};
77
78pub fn execute_columnar_aggregate(
114 rows: &[Row],
115 predicates: &[ColumnPredicate],
116 aggregates: &[aggregate::AggregateSpec],
117 schema: Option<&CombinedSchema>,
118) -> Result<Vec<Row>, ExecutorError> {
119 if rows.is_empty() {
122 let values: Vec<SqlValue> = aggregates
123 .iter()
124 .map(|spec| match spec.op {
125 aggregate::AggregateOp::Count => SqlValue::Integer(0),
126 _ => SqlValue::Null,
127 })
128 .collect();
129 return Ok(vec![Row::new(values)]);
130 }
131
132 #[cfg(feature = "profile-q6")]
134 let batch_start = std::time::Instant::now();
135
136 let batch = ColumnarBatch::from_rows(rows)?;
137
138 #[cfg(feature = "profile-q6")]
139 {
140 let batch_time = batch_start.elapsed();
141 eprintln!("[PROFILE-Q6] Phase 1 - Convert to batch: {:?}", batch_time);
142 }
143
144 #[cfg(feature = "profile-q6")]
146 let filter_start = std::time::Instant::now();
147
148 let filtered_batch =
149 if predicates.is_empty() { batch.clone() } else { simd_filter_batch(&batch, predicates)? };
150
151 #[cfg(feature = "profile-q6")]
152 {
153 let filter_time = filter_start.elapsed();
154 eprintln!(
155 "[PROFILE-Q6] Phase 2 - SIMD filter: {:?} ({}/{} rows passed)",
156 filter_time,
157 filtered_batch.row_count(),
158 rows.len()
159 );
160 }
161
162 #[cfg(feature = "profile-q6")]
164 let agg_start = std::time::Instant::now();
165
166 let results = compute_aggregates_from_batch(&filtered_batch, aggregates, schema)?;
168
169 #[cfg(feature = "profile-q6")]
170 {
171 let agg_time = agg_start.elapsed();
172 eprintln!(
173 "[PROFILE-Q6] Phase 3 - Batch-native aggregate: {:?} ({} aggregates)",
174 agg_time,
175 aggregates.len()
176 );
177 }
178
179 Ok(vec![Row::new(results)])
181}
182
183pub fn fast_aggregate_on_rows(
206 rows: &[Row],
207 predicates: &[ColumnPredicate],
208 aggregates: &[aggregate::AggregateSpec],
209) -> Result<Vec<Row>, ExecutorError> {
210 use aggregate::{AggregateOp, AggregateSource};
211
212 if rows.is_empty() {
214 let values: Vec<SqlValue> = aggregates
215 .iter()
216 .map(|spec| match spec.op {
217 AggregateOp::Count => SqlValue::Integer(0),
218 _ => SqlValue::Null,
219 })
220 .collect();
221 return Ok(vec![Row::new(values)]);
222 }
223
224 struct Accumulator {
226 sum_f64: f64,
227 sum_i64: i64,
228 count: i64,
229 min_f64: Option<f64>,
230 max_f64: Option<f64>,
231 min_i64: Option<i64>,
232 max_i64: Option<i64>,
233 is_integer: bool,
234 }
235
236 let mut accumulators: Vec<Accumulator> = aggregates
237 .iter()
238 .map(|_| Accumulator {
239 sum_f64: 0.0,
240 sum_i64: 0,
241 count: 0,
242 min_f64: None,
243 max_f64: None,
244 min_i64: None,
245 max_i64: None,
246 is_integer: true,
247 })
248 .collect();
249
250 for row in rows {
252 let passes_filter = predicates.iter().all(|pred| evaluate_predicate(row, pred));
254
255 if !passes_filter {
256 continue;
257 }
258
259 for (i, spec) in aggregates.iter().enumerate() {
261 let acc = &mut accumulators[i];
262
263 match &spec.source {
264 AggregateSource::CountStar => {
265 acc.count += 1;
266 }
267 AggregateSource::Column(col_idx) => {
268 if let Some(value) = row.get(*col_idx) {
269 if !matches!(value, SqlValue::Null) {
270 acc.count += 1;
271 match value {
272 SqlValue::Integer(v) => {
273 acc.sum_i64 += v;
274 acc.sum_f64 += *v as f64;
275 acc.min_i64 = Some(acc.min_i64.map_or(*v, |m| m.min(*v)));
276 acc.max_i64 = Some(acc.max_i64.map_or(*v, |m| m.max(*v)));
277 acc.min_f64 =
278 Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
279 acc.max_f64 =
280 Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
281 }
282 SqlValue::Double(v) => {
283 acc.is_integer = false;
284 acc.sum_f64 += v;
285 acc.min_f64 = Some(acc.min_f64.map_or(*v, |m| m.min(*v)));
286 acc.max_f64 = Some(acc.max_f64.map_or(*v, |m| m.max(*v)));
287 }
288 SqlValue::Float(v) => {
289 acc.is_integer = false;
290 acc.sum_f64 += *v as f64;
291 acc.min_f64 =
292 Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
293 acc.max_f64 =
294 Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
295 }
296 SqlValue::Bigint(v) => {
297 acc.sum_i64 += v;
298 acc.sum_f64 += *v as f64;
299 acc.min_i64 = Some(acc.min_i64.map_or(*v, |m| m.min(*v)));
300 acc.max_i64 = Some(acc.max_i64.map_or(*v, |m| m.max(*v)));
301 acc.min_f64 =
302 Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
303 acc.max_f64 =
304 Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
305 }
306 SqlValue::Numeric(v) => {
307 acc.is_integer = false;
308 acc.sum_f64 += v;
309 acc.min_f64 = Some(acc.min_f64.map_or(*v, |m| m.min(*v)));
310 acc.max_f64 = Some(acc.max_f64.map_or(*v, |m| m.max(*v)));
311 }
312 _ => {}
313 }
314 }
315 }
316 }
317 AggregateSource::Expression(expr) => {
318 if let Some(value) = eval_simple_expression(row, expr) {
321 acc.count += 1;
322 acc.is_integer = false;
323 acc.sum_f64 += value;
324 acc.min_f64 = Some(acc.min_f64.map_or(value, |m| m.min(value)));
325 acc.max_f64 = Some(acc.max_f64.map_or(value, |m| m.max(value)));
326 }
327 }
328 }
329 }
330 }
331
332 let values: Vec<SqlValue> = aggregates
334 .iter()
335 .zip(accumulators.iter())
336 .map(|(spec, acc)| match spec.op {
337 AggregateOp::Count => SqlValue::Integer(acc.count),
338 AggregateOp::Sum => {
339 if acc.count == 0 {
341 SqlValue::Null
342 } else if acc.is_integer {
343 SqlValue::Integer(acc.sum_i64)
344 } else {
345 SqlValue::Double(acc.sum_f64)
346 }
347 }
348 AggregateOp::Avg => {
349 if acc.count == 0 {
350 SqlValue::Null
351 } else {
352 SqlValue::Double(acc.sum_f64 / acc.count as f64)
353 }
354 }
355 AggregateOp::Min => {
356 if acc.is_integer {
357 acc.min_i64.map(SqlValue::Integer).unwrap_or(SqlValue::Null)
358 } else {
359 acc.min_f64.map(SqlValue::Double).unwrap_or(SqlValue::Null)
360 }
361 }
362 AggregateOp::Max => {
363 if acc.is_integer {
364 acc.max_i64.map(SqlValue::Integer).unwrap_or(SqlValue::Null)
365 } else {
366 acc.max_f64.map(SqlValue::Double).unwrap_or(SqlValue::Null)
367 }
368 }
369 })
370 .collect();
371
372 Ok(vec![Row::new(values)])
373}
374
375fn evaluate_predicate(row: &Row, predicate: &ColumnPredicate) -> bool {
377 match predicate {
378 ColumnPredicate::LessThan { column_idx, value } => row
379 .get(*column_idx)
380 .map(|v| compare_values(v, value) == std::cmp::Ordering::Less)
381 .unwrap_or(false),
382 ColumnPredicate::LessThanOrEqual { column_idx, value } => row
383 .get(*column_idx)
384 .map(|v| compare_values(v, value) != std::cmp::Ordering::Greater)
385 .unwrap_or(false),
386 ColumnPredicate::GreaterThan { column_idx, value } => row
387 .get(*column_idx)
388 .map(|v| compare_values(v, value) == std::cmp::Ordering::Greater)
389 .unwrap_or(false),
390 ColumnPredicate::GreaterThanOrEqual { column_idx, value } => row
391 .get(*column_idx)
392 .map(|v| compare_values(v, value) != std::cmp::Ordering::Less)
393 .unwrap_or(false),
394 ColumnPredicate::Equal { column_idx, value } => row
395 .get(*column_idx)
396 .map(|v| compare_values(v, value) == std::cmp::Ordering::Equal)
397 .unwrap_or(false),
398 ColumnPredicate::NotEqual { column_idx, value } => row
399 .get(*column_idx)
400 .map(|v| compare_values(v, value) != std::cmp::Ordering::Equal)
401 .unwrap_or(false),
402 ColumnPredicate::Between { column_idx, low, high } => row
403 .get(*column_idx)
404 .map(|v| {
405 compare_values(v, low) != std::cmp::Ordering::Less
406 && compare_values(v, high) != std::cmp::Ordering::Greater
407 })
408 .unwrap_or(false),
409 ColumnPredicate::Like { column_idx, pattern, negated } => {
410 let matches = row
412 .get(*column_idx)
413 .map(|v| {
414 if let SqlValue::Varchar(s) = v {
415 let pattern_str = pattern.as_str();
418 if let Some(inner) =
419 pattern_str.strip_prefix('%').and_then(|s| s.strip_suffix('%'))
420 {
421 s.contains(inner)
422 } else if let Some(suffix) = pattern_str.strip_prefix('%') {
423 s.ends_with(suffix)
424 } else if let Some(prefix) = pattern_str.strip_suffix('%') {
425 s.starts_with(prefix)
426 } else {
427 &**s == pattern_str
428 }
429 } else {
430 false
431 }
432 })
433 .unwrap_or(false);
434 if *negated {
435 !matches
436 } else {
437 matches
438 }
439 }
440 ColumnPredicate::InList { column_idx, values, negated, use_strict_type_ordering } => {
441 let matches = row
443 .get(*column_idx)
444 .map(|v| {
445 values.iter().any(|list_val| {
446 if *use_strict_type_ordering {
447 strict_type_equal(v, list_val)
449 } else {
450 compare_values(v, list_val) == std::cmp::Ordering::Equal
451 }
452 })
453 })
454 .unwrap_or(false);
455 if *negated {
456 !matches
457 } else {
458 matches
459 }
460 }
461 ColumnPredicate::ColumnCompare { left_column_idx, op, right_column_idx } => {
462 let left_val = row.get(*left_column_idx);
464 let right_val = row.get(*right_column_idx);
465 match (left_val, right_val) {
466 (Some(l), Some(r)) => {
467 use std::cmp::Ordering;
468 let cmp = compare_values(l, r);
469 match op {
470 filter::CompareOp::LessThan => cmp == Ordering::Less,
471 filter::CompareOp::GreaterThan => cmp == Ordering::Greater,
472 filter::CompareOp::LessThanOrEqual => cmp != Ordering::Greater,
473 filter::CompareOp::GreaterThanOrEqual => cmp != Ordering::Less,
474 filter::CompareOp::Equal => cmp == Ordering::Equal,
475 filter::CompareOp::NotEqual => cmp != Ordering::Equal,
476 }
477 }
478 _ => false, }
480 }
481 }
482}
483
484fn compare_values(a: &SqlValue, b: &SqlValue) -> std::cmp::Ordering {
486 use std::cmp::Ordering;
487
488 match (a, b) {
489 (SqlValue::Integer(a), SqlValue::Integer(b)) => a.cmp(b),
490 (SqlValue::Bigint(a), SqlValue::Bigint(b)) => a.cmp(b),
491 (SqlValue::Double(a), SqlValue::Double(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
492 (SqlValue::Float(a), SqlValue::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
493 (SqlValue::Integer(a), SqlValue::Double(b)) => {
495 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
496 }
497 (SqlValue::Double(a), SqlValue::Integer(b)) => {
498 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
499 }
500 (SqlValue::Integer(a), SqlValue::Bigint(b)) => (*a).cmp(b),
501 (SqlValue::Bigint(a), SqlValue::Integer(b)) => a.cmp(&{ *b }),
502 (SqlValue::Varchar(a), SqlValue::Varchar(b)) => a.cmp(b),
504 (SqlValue::Date(a), SqlValue::Date(b)) => {
506 match a.year.cmp(&b.year) {
508 Ordering::Equal => match a.month.cmp(&b.month) {
509 Ordering::Equal => a.day.cmp(&b.day),
510 other => other,
511 },
512 other => other,
513 }
514 }
515 (SqlValue::Null, _) | (_, SqlValue::Null) => Ordering::Equal, _ => Ordering::Equal, }
520}
521
522fn strict_type_equal(a: &SqlValue, b: &SqlValue) -> bool {
531 if matches!(a, SqlValue::Null) || matches!(b, SqlValue::Null) {
533 return false;
534 }
535
536 let is_string = |v: &SqlValue| matches!(v, SqlValue::Varchar(_) | SqlValue::Character(_));
538
539 let is_numeric = |v: &SqlValue| {
541 matches!(
542 v,
543 SqlValue::Integer(_)
544 | SqlValue::Bigint(_)
545 | SqlValue::Smallint(_)
546 | SqlValue::Float(_)
547 | SqlValue::Real(_)
548 | SqlValue::Double(_)
549 | SqlValue::Numeric(_)
550 )
551 };
552
553 if (is_string(a) && is_numeric(b)) || (is_numeric(a) && is_string(b)) {
555 return false;
556 }
557
558 match (a, b) {
560 (SqlValue::Integer(x), SqlValue::Integer(y)) => x == y,
562 (SqlValue::Bigint(x), SqlValue::Bigint(y)) => x == y,
563 (SqlValue::Smallint(x), SqlValue::Smallint(y)) => x == y,
564
565 (SqlValue::Varchar(x), SqlValue::Varchar(y))
567 | (SqlValue::Character(x), SqlValue::Character(y)) => x == y,
568 (SqlValue::Varchar(x), SqlValue::Character(y))
569 | (SqlValue::Character(x), SqlValue::Varchar(y)) => x == y,
570
571 (SqlValue::Float(x), SqlValue::Float(y)) => (x - y).abs() < f32::EPSILON,
573 (SqlValue::Double(x), SqlValue::Double(y)) => (x - y).abs() < f64::EPSILON,
574 (SqlValue::Real(x), SqlValue::Real(y)) => (x - y).abs() < f64::EPSILON,
575
576 (SqlValue::Integer(x), SqlValue::Bigint(y))
578 | (SqlValue::Bigint(y), SqlValue::Integer(x)) => *x as i64 == *y,
579
580 (SqlValue::Float(x), SqlValue::Integer(y))
582 | (SqlValue::Integer(y), SqlValue::Float(x)) => (*x as f64 - *y as f64).abs() < f64::EPSILON,
583 (SqlValue::Double(x), SqlValue::Integer(y))
584 | (SqlValue::Integer(y), SqlValue::Double(x)) => (*x - *y as f64).abs() < f64::EPSILON,
585 (SqlValue::Real(x), SqlValue::Integer(y))
586 | (SqlValue::Integer(y), SqlValue::Real(x)) => (*x - *y as f64).abs() < f64::EPSILON,
587
588 (SqlValue::Float(x), SqlValue::Double(y))
590 | (SqlValue::Double(y), SqlValue::Float(x)) => (*x as f64 - *y).abs() < f64::EPSILON,
591 (SqlValue::Float(x), SqlValue::Real(y))
592 | (SqlValue::Real(y), SqlValue::Float(x)) => (*x as f64 - *y).abs() < f64::EPSILON,
593 (SqlValue::Double(x), SqlValue::Real(y))
594 | (SqlValue::Real(y), SqlValue::Double(x)) => (*x - *y).abs() < f64::EPSILON,
595
596 (SqlValue::Numeric(x), SqlValue::Integer(y))
598 | (SqlValue::Integer(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
599 (SqlValue::Numeric(x), SqlValue::Bigint(y))
600 | (SqlValue::Bigint(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
601 (SqlValue::Numeric(x), SqlValue::Smallint(y))
602 | (SqlValue::Smallint(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
603
604 (SqlValue::Numeric(x), SqlValue::Float(y))
606 | (SqlValue::Float(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
607 (SqlValue::Numeric(x), SqlValue::Double(y))
608 | (SqlValue::Double(y), SqlValue::Numeric(x)) => (*x - *y).abs() < f64::EPSILON,
609 (SqlValue::Numeric(x), SqlValue::Real(y))
610 | (SqlValue::Real(y), SqlValue::Numeric(x)) => (*x - *y).abs() < f64::EPSILON,
611 (SqlValue::Numeric(x), SqlValue::Numeric(y)) => (*x - *y).abs() < f64::EPSILON,
612
613 _ => false,
615 }
616}
617
618#[allow(clippy::only_used_in_recursion)]
620fn eval_simple_expression(row: &Row, expr: &vibesql_ast::Expression) -> Option<f64> {
621 use vibesql_ast::{BinaryOperator, Expression};
622
623 match expr {
624 Expression::BinaryOp { left, op, right } => {
625 let left_val = eval_simple_expression(row, left)?;
626 let right_val = eval_simple_expression(row, right)?;
627 match op {
628 BinaryOperator::Multiply => Some(left_val * right_val),
629 BinaryOperator::Divide => Some(left_val / right_val),
630 BinaryOperator::Plus => Some(left_val + right_val),
631 BinaryOperator::Minus => Some(left_val - right_val),
632 _ => None,
633 }
634 }
635 Expression::ColumnRef(col_id) => {
636 log::debug!(
640 "fast_aggregate_on_rows: ColumnRef '{}' requires schema resolution, skipping fast path",
641 col_id.column_canonical()
642 );
643 None
644 }
645 Expression::Literal(val) => match val {
646 SqlValue::Integer(v) => Some(*v as f64),
647 SqlValue::Double(v) => Some(*v),
648 SqlValue::Float(v) => Some(*v as f64),
649 SqlValue::Bigint(v) => Some(*v as f64),
650 SqlValue::Numeric(v) => Some(*v),
651 _ => None,
652 },
653 _ => None,
654 }
655}
656
657pub fn execute_columnar(
676 rows: &[Row],
677 filter: Option<&vibesql_ast::Expression>,
678 aggregates: &[vibesql_ast::Expression],
679 schema: &CombinedSchema,
680) -> Option<Result<Vec<Row>, ExecutorError>> {
681 log::debug!(" Executing columnar query with {} rows", rows.len());
682
683 let predicates = if let Some(filter_expr) = filter {
685 match extract_column_predicates(filter_expr, schema) {
686 Some(preds) => {
687 log::debug!(" ✓ Extracted {} column predicates for SIMD filtering", preds.len());
688 preds
689 }
690 None => {
691 log::debug!(" ✗ Filter too complex for columnar optimization");
692 return None; }
694 }
695 } else {
696 log::debug!(" No filter predicates");
697 vec![] };
699
700 let agg_specs = match extract_aggregates(aggregates, schema) {
702 Some(specs) => {
703 log::debug!(" ✓ Extracted {} aggregate operations", specs.len());
704 for (i, spec) in specs.iter().enumerate() {
705 log::debug!(" Aggregate {}: {:?}", i + 1, spec.op);
706 }
707 specs
708 }
709 None => {
710 log::debug!(" ✗ Aggregates too complex for columnar optimization");
711 return None; }
713 };
714
715 let needs_schema = agg_specs
717 .iter()
718 .any(|spec| matches!(spec.source, aggregate::AggregateSource::Expression(_)));
719 let schema_ref = if needs_schema { Some(schema) } else { None };
720
721 log::debug!(" Executing SIMD-accelerated columnar aggregation");
722 Some(execute_columnar_aggregate(rows, &predicates, &agg_specs, schema_ref))
723}
724
#[cfg(test)]
mod tests {
    use vibesql_ast::{BinaryOperator, Expression};
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::{DataType, Date};

    use super::*;
    use crate::schema::CombinedSchema;

    /// Assert that `value` is a `Double` within 0.001 of `expected`.
    ///
    /// `what` names the aggregate in the failure message.
    fn assert_double(value: Option<&SqlValue>, expected: f64, what: &str) {
        match value {
            Some(SqlValue::Double(actual)) => assert!(
                (actual - expected).abs() < 0.001,
                "{}: expected {}, got {}",
                what,
                expected,
                actual
            ),
            other => panic!("Expected Double value for {}, got {:?}", what, other),
        }
    }

    // ---- Pipeline tests driven by pre-built predicates and aggregate specs ----

    /// TPC-H Q6 shape: `quantity < 24 AND discount BETWEEN 0.05 AND 0.07`,
    /// then `SUM(price)` and `COUNT(quantity)` over the surviving rows.
    #[test]
    fn test_columnar_pipeline_filtered_sum() {
        // Columns: quantity, price, discount, ship date.
        let rows = vec![
            // Passes both predicates.
            Row::new(vec![
                SqlValue::Integer(10),
                SqlValue::Double(100.0),
                SqlValue::Double(0.06),
                SqlValue::Date(Date::new(1994, 6, 1).unwrap()),
            ]),
            // Fails `quantity < 24`.
            Row::new(vec![
                SqlValue::Integer(25),
                SqlValue::Double(200.0),
                SqlValue::Double(0.06),
                SqlValue::Date(Date::new(1994, 7, 1).unwrap()),
            ]),
            // Passes both predicates (discount on the lower bound).
            Row::new(vec![
                SqlValue::Integer(15),
                SqlValue::Double(150.0),
                SqlValue::Double(0.05),
                SqlValue::Date(Date::new(1994, 8, 1).unwrap()),
            ]),
            // Fails the discount range.
            Row::new(vec![
                SqlValue::Integer(20),
                SqlValue::Double(180.0),
                SqlValue::Double(0.08),
                SqlValue::Date(Date::new(1994, 9, 1).unwrap()),
            ]),
        ];

        let predicates = vec![
            ColumnPredicate::LessThan { column_idx: 0, value: SqlValue::Integer(24) },
            ColumnPredicate::Between {
                column_idx: 2,
                low: SqlValue::Double(0.05),
                high: SqlValue::Double(0.07),
            },
        ];
        let aggregates = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(1) },
            AggregateSpec { op: AggregateOp::Count, source: AggregateSource::Column(0) },
        ];

        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();

        assert_eq!(result.len(), 1);
        let result_row = &result[0];
        // 100.0 + 150.0 from the two passing rows.
        assert_double(result_row.get(0), 250.0, "SUM");
        assert_eq!(result_row.get(1), Some(&SqlValue::Integer(2)));
    }

    /// Without predicates every row contributes; integral SUM/MAX stay Integer.
    #[test]
    fn test_columnar_pipeline_no_filter() {
        let rows = vec![
            Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
            Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
            Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
        ];

        let predicates: Vec<ColumnPredicate> = Vec::new();
        let aggregates = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
            AggregateSpec { op: AggregateOp::Max, source: AggregateSource::Column(0) },
        ];

        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();

        assert_eq!(result.len(), 1);
        let result_row = &result[0];
        assert_eq!(result_row.get(0), Some(&SqlValue::Integer(60)));
        assert_double(result_row.get(1), 2.5, "AVG");
        assert_eq!(result_row.get(2), Some(&SqlValue::Integer(30)));
    }

    /// When the filter rejects every row, SUM is NULL and COUNT is 0.
    #[test]
    fn test_columnar_pipeline_empty_result() {
        let rows =
            vec![Row::new(vec![SqlValue::Integer(100)]), Row::new(vec![SqlValue::Integer(200)])];

        let predicates =
            vec![ColumnPredicate::LessThan { column_idx: 0, value: SqlValue::Integer(50) }];
        let aggregates = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Count, source: AggregateSource::Column(0) },
        ];

        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();

        assert_eq!(result.len(), 1);
        let result_row = &result[0];
        assert_eq!(result_row.get(0), Some(&SqlValue::Null));
        assert_eq!(result_row.get(1), Some(&SqlValue::Integer(0)));
    }

    // ---- End-to-end tests through `execute_columnar` (AST input) ----

    /// Single table `test(quantity INTEGER, price DOUBLE PRECISION)`.
    fn make_test_schema() -> CombinedSchema {
        let schema = TableSchema::new(
            "test".to_string(),
            vec![
                ColumnSchema::new("quantity".to_string(), DataType::Integer, false),
                ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
            ],
        );
        CombinedSchema::from_table("test".to_string(), schema)
    }

    /// Four rows: quantity 10..=40 step 10, price 1.5..=4.5 step 1.0.
    fn make_test_rows_for_ast() -> Vec<Row> {
        vec![
            Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
            Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
            Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
            Row::new(vec![SqlValue::Integer(40), SqlValue::Double(4.5)]),
        ]
    }

    /// Unqualified column reference to `name`.
    fn column(name: &'static str) -> Expression {
        Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(name, false))
    }

    /// Aggregate call `name(arg)` with no `ORDER BY` and no `FILTER`.
    fn aggregate_call(name: &'static str, distinct: bool, arg: Expression) -> Expression {
        Expression::AggregateFunction {
            name: vibesql_ast::FunctionIdentifier::new(name),
            distinct,
            args: vec![arg],
            order_by: None,
            filter: None,
        }
    }

    #[test]
    fn test_execute_columnar_simple_aggregate() {
        let rows = make_test_rows_for_ast();
        let schema = make_test_schema();

        let aggregates = vec![aggregate_call("SUM", false, column("price"))];

        let result = execute_columnar(&rows, None, &aggregates, &schema);
        assert!(result.is_some());

        let rows_result = result.unwrap();
        assert!(rows_result.is_ok());

        let result_rows = rows_result.unwrap();
        assert_eq!(result_rows.len(), 1);
        assert_eq!(result_rows[0].len(), 1);
        // 1.5 + 2.5 + 3.5 + 4.5
        assert_double(result_rows[0].get(0), 12.0, "SUM");
    }

    #[test]
    fn test_execute_columnar_with_filter() {
        let rows = make_test_rows_for_ast();
        let schema = make_test_schema();

        // WHERE quantity < 25
        let filter = Expression::BinaryOp {
            left: Box::new(column("quantity")),
            op: BinaryOperator::LessThan,
            right: Box::new(Expression::Literal(SqlValue::Integer(25))),
        };
        let aggregates = vec![aggregate_call("SUM", false, column("price"))];

        let result = execute_columnar(&rows, Some(&filter), &aggregates, &schema);
        assert!(result.is_some());

        let rows_result = result.unwrap();
        assert!(rows_result.is_ok());

        let result_rows = rows_result.unwrap();
        assert_eq!(result_rows.len(), 1);
        assert_eq!(result_rows[0].len(), 1);
        // Only quantities 10 and 20 pass: 1.5 + 2.5.
        assert_double(result_rows[0].get(0), 4.0, "SUM");
    }

    #[test]
    fn test_execute_columnar_multiple_aggregates() {
        let rows = make_test_rows_for_ast();
        let schema = make_test_schema();

        let aggregates = vec![
            aggregate_call("SUM", false, column("price")),
            aggregate_call("COUNT", false, Expression::Wildcard),
            aggregate_call("AVG", false, column("quantity")),
        ];

        let result = execute_columnar(&rows, None, &aggregates, &schema);
        assert!(result.is_some());

        let rows_result = result.unwrap();
        assert!(rows_result.is_ok());

        let result_rows = rows_result.unwrap();
        assert_eq!(result_rows.len(), 1);
        assert_eq!(result_rows[0].len(), 3);

        assert_double(result_rows[0].get(0), 12.0, "SUM");
        assert_eq!(result_rows[0].get(1), Some(&SqlValue::Integer(4)));
        // (10 + 20 + 30 + 40) / 4
        assert_double(result_rows[0].get(2), 25.0, "AVG");
    }

    /// DISTINCT aggregates are not supported by the columnar path.
    #[test]
    fn test_execute_columnar_unsupported_distinct() {
        let rows = make_test_rows_for_ast();
        let schema = make_test_schema();

        let aggregates = vec![aggregate_call("SUM", true, column("price"))];

        let result = execute_columnar(&rows, None, &aggregates, &schema);
        assert!(result.is_none());
    }

    /// A subquery filter cannot be lowered to column predicates.
    #[test]
    fn test_execute_columnar_unsupported_complex_filter() {
        let rows = make_test_rows_for_ast();
        let schema = make_test_schema();

        let filter = Expression::ScalarSubquery(Box::new(vibesql_ast::SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![],
            into_table: None,
            into_variables: None,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
            values: None,
        }));
        let aggregates = vec![aggregate_call("SUM", false, column("price"))];

        let result = execute_columnar(&rows, Some(&filter), &aggregates, &schema);
        assert!(result.is_none());
    }
}