1mod expression;
14mod functions;
15mod group_by;
16
17pub use expression::{
19 evaluate_expression_to_column, evaluate_expression_with_cached_column, extract_aggregates,
20};
21pub use group_by::columnar_group_by;
22pub use group_by::columnar_group_by_batch;
24use vibesql_ast::Expression;
25use vibesql_storage::Row;
26use vibesql_types::SqlValue;
27
28use super::{batch::ColumnarBatch, scan::ColumnarScan};
29use crate::{errors::ExecutorError, schema::CombinedSchema};
30
/// Aggregate function selected for a columnar aggregation.
///
/// The value is `Copy`, so it is passed around by value throughout this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// Sum of the input values (`SUM`).
    Sum,
    /// Number of input rows/values (`COUNT`).
    Count,
    /// Arithmetic mean of the input values (`AVG`).
    Avg,
    /// Smallest input value (`MIN`).
    Min,
    /// Largest input value (`MAX`).
    Max,
}
40
41#[derive(Debug, Clone, PartialEq)]
43pub enum AggregateSource {
44 Column(usize),
46 Expression(Expression),
48 CountStar,
50}
51
52#[derive(Debug, Clone)]
54pub struct AggregateSpec {
55 pub op: AggregateOp,
56 pub source: AggregateSource,
57}
58
59pub fn compute_columnar_aggregate(
78 scan: &ColumnarScan,
79 column_idx: usize,
80 op: AggregateOp,
81 filter_bitmap: Option<&[bool]>,
82) -> Result<SqlValue, ExecutorError> {
83 {
85 use super::simd_aggregate::{
86 can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64,
87 };
88
89 if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
91 return if is_integer {
93 simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
94 } else {
95 simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
96 };
97 }
98 }
100
101 functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
103}
104
105pub fn compute_multiple_aggregates(
110 rows: &[Row],
111 aggregates: &[AggregateSpec],
112 filter_bitmap: Option<&[bool]>,
113 schema: Option<&CombinedSchema>,
114) -> Result<Vec<SqlValue>, ExecutorError> {
115 let scan = ColumnarScan::new(rows);
116 let mut results = Vec::with_capacity(aggregates.len());
117
118 for spec in aggregates {
119 let result = match &spec.source {
120 AggregateSource::Column(column_idx) => {
122 compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
123 }
124 AggregateSource::Expression(expr) => {
126 let schema = schema.ok_or_else(|| {
127 ExecutorError::UnsupportedExpression(
128 "Schema required for expression aggregates".to_string(),
129 )
130 })?;
131 expression::compute_expression_aggregate(
132 rows,
133 expr,
134 spec.op,
135 filter_bitmap,
136 schema,
137 )?
138 }
139 AggregateSource::CountStar => functions::compute_count(&scan, filter_bitmap)?,
141 };
142 results.push(result);
143 }
144
145 Ok(results)
146}
147
148pub fn compute_aggregates_from_batch(
171 batch: &ColumnarBatch,
172 aggregates: &[AggregateSpec],
173 schema: Option<&CombinedSchema>,
174) -> Result<Vec<SqlValue>, ExecutorError> {
175 if batch.row_count() == 0 {
177 return Ok(aggregates
178 .iter()
179 .map(|spec| match spec.op {
180 AggregateOp::Count => SqlValue::Integer(0),
181 _ => SqlValue::Null,
182 })
183 .collect());
184 }
185
186 let mut results = Vec::with_capacity(aggregates.len());
187
188 for spec in aggregates {
189 let result = match &spec.source {
190 AggregateSource::Column(column_idx) => {
192 functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
193 }
194 AggregateSource::Expression(expr) => {
197 let schema = schema.ok_or_else(|| {
198 ExecutorError::UnsupportedExpression(
199 "Schema required for expression aggregates".to_string(),
200 )
201 })?;
202 match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
205 Ok(value) => value,
206 Err(ExecutorError::UnsupportedExpression(_)) => {
207 let rows = batch.to_rows()?;
209 expression::compute_expression_aggregate(
210 &rows, expr, spec.op, None, schema,
211 )?
212 }
213 Err(other) => return Err(other),
214 }
215 }
216 AggregateSource::CountStar => SqlValue::Integer(batch.row_count() as i64),
218 };
219 results.push(result);
220 }
221
222 Ok(results)
223}
224
225#[cfg(test)]
226mod tests {
227 use super::*;
228
229 fn make_test_rows() -> Vec<Row> {
230 vec![
231 Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
232 Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
233 Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
234 ]
235 }
236
237 #[test]
238 fn test_sum_aggregate() {
239 let rows = make_test_rows();
240 let scan = ColumnarScan::new(&rows);
241
242 let result = functions::compute_sum(&scan, 0, None).unwrap();
243 assert_eq!(result, SqlValue::Integer(60));
244
245 let result = functions::compute_sum(&scan, 1, None).unwrap();
246 assert!(matches!(result, SqlValue::Double(sum) if (sum - 7.5).abs() < 0.001));
247 }
248
249 #[test]
250 fn test_count_aggregate() {
251 let rows = make_test_rows();
252 let scan = ColumnarScan::new(&rows);
253
254 let result = functions::compute_count(&scan, None).unwrap();
255 assert_eq!(result, SqlValue::Integer(3));
256 }
257
258 #[test]
259 fn test_sum_with_filter() {
260 let rows = make_test_rows();
261 let scan = ColumnarScan::new(&rows);
262 let filter = vec![true, false, true]; let result = functions::compute_sum(&scan, 0, Some(&filter)).unwrap();
265 assert_eq!(result, SqlValue::Integer(40));
266 }
267
268 #[test]
269 fn test_multiple_aggregates() {
270 let rows = make_test_rows();
271 let aggregates = vec![
272 AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
273 AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
274 ];
275
276 let results = compute_multiple_aggregates(&rows, &aggregates, None, None).unwrap();
277 assert_eq!(results.len(), 2);
278 assert_eq!(results[0], SqlValue::Integer(60));
279 assert!(matches!(results[1], SqlValue::Double(avg) if (avg - 2.5).abs() < 0.001));
280 }
281
282 #[test]
283 fn test_extract_aggregates_simple() {
284 use vibesql_catalog::{ColumnSchema, TableSchema};
285 use vibesql_types::DataType;
286
287 use crate::schema::CombinedSchema;
288
289 let schema = TableSchema::new(
291 "test".to_string(),
292 vec![
293 ColumnSchema::new("col1".to_string(), DataType::Integer, false),
294 ColumnSchema::new("col2".to_string(), DataType::DoublePrecision, false),
295 ],
296 );
297
298 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
299
300 let exprs = vec![Expression::AggregateFunction {
302 name: vibesql_ast::FunctionIdentifier::new("SUM"),
303 distinct: false,
304 args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple("col1", false))],
305 order_by: None,
306 filter: None,
307 }];
308
309 let result = extract_aggregates(&exprs, &combined_schema);
310 assert!(result.is_some());
311 let aggregates = result.unwrap();
312 assert_eq!(aggregates.len(), 1);
313 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
314 assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
315
316 let exprs = vec![Expression::AggregateFunction {
318 name: vibesql_ast::FunctionIdentifier::new("COUNT"),
319 distinct: false,
320 args: vec![Expression::Wildcard],
321 order_by: None,
322 filter: None,
323 }];
324
325 let result = extract_aggregates(&exprs, &combined_schema);
326 assert!(result.is_some());
327 let aggregates = result.unwrap();
328 assert_eq!(aggregates.len(), 1);
329 assert!(matches!(aggregates[0].op, AggregateOp::Count));
330 assert!(matches!(aggregates[0].source, AggregateSource::CountStar));
331
332 let exprs = vec![
334 Expression::AggregateFunction {
335 name: vibesql_ast::FunctionIdentifier::new("SUM"),
336 distinct: false,
337 args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
338 "col1", false,
339 ))],
340 order_by: None,
341 filter: None,
342 },
343 Expression::AggregateFunction {
344 name: vibesql_ast::FunctionIdentifier::new("AVG"),
345 distinct: false,
346 args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
347 "col2", false,
348 ))],
349 order_by: None,
350 filter: None,
351 },
352 ];
353
354 let result = extract_aggregates(&exprs, &combined_schema);
355 assert!(result.is_some());
356 let aggregates = result.unwrap();
357 assert_eq!(aggregates.len(), 2);
358 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
359 assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
360 assert!(matches!(aggregates[1].op, AggregateOp::Avg));
361 assert!(matches!(aggregates[1].source, AggregateSource::Column(1)));
362 }
363
364 #[test]
365 fn test_extract_aggregates_unsupported() {
366 use vibesql_catalog::{ColumnSchema, TableSchema};
367 use vibesql_types::DataType;
368
369 use crate::schema::CombinedSchema;
370
371 let schema = TableSchema::new(
372 "test".to_string(),
373 vec![ColumnSchema::new("col1".to_string(), DataType::Integer, false)],
374 );
375
376 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
377
378 let exprs = vec![Expression::AggregateFunction {
380 name: vibesql_ast::FunctionIdentifier::new("SUM"),
381 distinct: true,
382 args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple("col1", false))],
383 order_by: None,
384 filter: None,
385 }];
386
387 let result = extract_aggregates(&exprs, &combined_schema);
388 assert!(result.is_none());
389
390 let exprs =
392 vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple("col1", false))];
393
394 let result = extract_aggregates(&exprs, &combined_schema);
395 assert!(result.is_none());
396
397 let exprs = vec![Expression::AggregateFunction {
399 name: vibesql_ast::FunctionIdentifier::new("SUM"),
400 distinct: false,
401 args: vec![Expression::ScalarSubquery(Box::new(vibesql_ast::SelectStmt {
402 with_clause: None,
403 distinct: false,
404 select_list: vec![],
405 into_table: None,
406 into_variables: None,
407 from: None,
408 where_clause: None,
409 group_by: None,
410 having: None,
411 order_by: None,
412 limit: None,
413 offset: None,
414 set_operation: None,
415 values: None,
416 }))],
417 order_by: None,
418 filter: None,
419 }];
420
421 let result = extract_aggregates(&exprs, &combined_schema);
422 assert!(result.is_none());
423 }
424
425 #[test]
426 fn test_extract_aggregates_with_expression() {
427 use vibesql_catalog::{ColumnSchema, TableSchema};
428 use vibesql_types::DataType;
429
430 use crate::schema::CombinedSchema;
431
432 let schema = TableSchema::new(
434 "test".to_string(),
435 vec![
436 ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
437 ColumnSchema::new("discount".to_string(), DataType::DoublePrecision, false),
438 ],
439 );
440
441 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
442
443 let exprs = vec![Expression::AggregateFunction {
445 name: vibesql_ast::FunctionIdentifier::new("SUM"),
446 distinct: false,
447 args: vec![Expression::BinaryOp {
448 left: Box::new(Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
449 "price", false,
450 ))),
451 op: vibesql_ast::BinaryOperator::Multiply,
452 right: Box::new(Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
453 "discount", false,
454 ))),
455 }],
456 order_by: None,
457 filter: None,
458 }];
459
460 let result = extract_aggregates(&exprs, &combined_schema);
461 assert!(result.is_some());
462 let aggregates = result.unwrap();
463 assert_eq!(aggregates.len(), 1);
464 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
465 assert!(matches!(aggregates[0].source, AggregateSource::Expression(_)));
466 }
467
468 #[test]
471 fn test_columnar_group_by_simple() {
472 let rows = vec![
475 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(100.0)]),
476 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Double(200.0)]),
477 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(150.0)]),
478 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Double(50.0)]),
479 ];
480
481 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
485
486 assert_eq!(result.len(), 2);
488
489 let mut sorted = result;
491 sorted.sort_by(|a, b| {
492 let a_key = a.get(0).unwrap();
493 let b_key = b.get(0).unwrap();
494 a_key.partial_cmp(b_key).unwrap()
495 });
496
497 assert_eq!(sorted[0].get(0), Some(&SqlValue::Varchar(arcstr::ArcStr::from("A"))));
499 assert!(
500 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
501 );
502
503 assert_eq!(sorted[1].get(0), Some(&SqlValue::Varchar(arcstr::ArcStr::from("B"))));
505 assert!(
506 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
507 );
508 }
509
510 #[test]
511 fn test_columnar_group_by_multiple_group_keys() {
512 let rows = vec![
515 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Integer(1)]),
516 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Integer(2)]),
517 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Integer(1)]),
518 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Integer(1)]),
519 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Integer(2)]),
520 ];
521
522 let group_cols = vec![0, 1]; let agg_cols = vec![(0, AggregateOp::Count)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
526
527 assert_eq!(result.len(), 4);
529
530 for row in &result {
532 let status = row.get(0).unwrap();
533 let category = row.get(1).unwrap();
534 let count = row.get(2).unwrap();
535
536 match (status, category) {
537 (SqlValue::Varchar(s), SqlValue::Integer(1)) if s.as_str() == "A" => {
538 assert_eq!(count, &SqlValue::Integer(2)); }
540 (SqlValue::Varchar(s), SqlValue::Integer(2)) if s.as_str() == "A" => {
541 assert_eq!(count, &SqlValue::Integer(1)); }
543 (SqlValue::Varchar(s), SqlValue::Integer(1)) if s.as_str() == "B" => {
544 assert_eq!(count, &SqlValue::Integer(1)); }
546 (SqlValue::Varchar(s), SqlValue::Integer(2)) if s.as_str() == "B" => {
547 assert_eq!(count, &SqlValue::Integer(1)); }
549 _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
550 }
551 }
552 }
553
554 #[test]
555 fn test_columnar_group_by_multiple_aggregates() {
556 let rows = vec![
559 Row::new(vec![SqlValue::Integer(1), SqlValue::Double(100.0), SqlValue::Integer(10)]),
560 Row::new(vec![SqlValue::Integer(2), SqlValue::Double(200.0), SqlValue::Integer(20)]),
561 Row::new(vec![SqlValue::Integer(1), SqlValue::Double(150.0), SqlValue::Integer(15)]),
562 ];
563
564 let group_cols = vec![0]; let agg_cols = vec![
566 (1, AggregateOp::Sum), (2, AggregateOp::Avg), (0, AggregateOp::Count), ];
570
571 let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
572
573 assert_eq!(result.len(), 2);
575
576 let mut sorted = result;
578 sorted.sort_by(|a, b| {
579 let a_key = a.get(0).unwrap();
580 let b_key = b.get(0).unwrap();
581 a_key.partial_cmp(b_key).unwrap()
582 });
583
584 assert_eq!(sorted[0].get(0), Some(&SqlValue::Integer(1)));
586 assert!(
587 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
588 );
589 assert!(
590 matches!(sorted[0].get(2), Some(&SqlValue::Double(avg)) if (avg - 12.5).abs() < 0.001)
591 );
592 assert_eq!(sorted[0].get(3), Some(&SqlValue::Integer(2)));
593
594 assert_eq!(sorted[1].get(0), Some(&SqlValue::Integer(2)));
596 assert!(
597 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 200.0).abs() < 0.001)
598 );
599 assert!(
600 matches!(sorted[1].get(2), Some(&SqlValue::Double(avg)) if (avg - 20.0).abs() < 0.001)
601 );
602 assert_eq!(sorted[1].get(3), Some(&SqlValue::Integer(1)));
603 }
604
605 #[test]
606 fn test_columnar_group_by_with_filter() {
607 let rows = vec![
610 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(100.0)]),
611 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Double(200.0)]),
612 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(150.0)]),
613 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("B")), SqlValue::Double(50.0)]),
614 ];
615
616 let filter = vec![false, true, true, false];
618
619 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, Some(&filter)).unwrap();
623
624 assert_eq!(result.len(), 2);
626
627 let mut sorted = result;
629 sorted.sort_by(|a, b| {
630 let a_key = a.get(0).unwrap();
631 let b_key = b.get(0).unwrap();
632 a_key.partial_cmp(b_key).unwrap()
633 });
634
635 assert_eq!(sorted[0].get(0), Some(&SqlValue::Varchar(arcstr::ArcStr::from("A"))));
637 assert!(
638 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 150.0).abs() < 0.001)
639 );
640
641 assert_eq!(sorted[1].get(0), Some(&SqlValue::Varchar(arcstr::ArcStr::from("B"))));
643 assert!(
644 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 200.0).abs() < 0.001)
645 );
646 }
647
648 #[test]
649 fn test_columnar_group_by_empty_input() {
650 let rows: Vec<Row> = vec![];
651 let group_cols = vec![0];
652 let agg_cols = vec![(1, AggregateOp::Sum)];
653
654 let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
655
656 assert_eq!(result.len(), 0);
658 }
659
660 #[test]
661 fn test_columnar_group_by_null_in_group_key() {
662 let rows = vec![
664 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(100.0)]),
665 Row::new(vec![SqlValue::Null, SqlValue::Double(200.0)]),
666 Row::new(vec![SqlValue::Varchar(arcstr::ArcStr::from("A")), SqlValue::Double(150.0)]),
667 Row::new(vec![SqlValue::Null, SqlValue::Double(50.0)]),
668 ];
669
670 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
674
675 assert_eq!(result.len(), 2);
677
678 let a_group = result
680 .iter()
681 .find(|r| matches!(r.get(0), Some(SqlValue::Varchar(s)) if s.as_str() == "A"));
682 let null_group = result.iter().find(|r| matches!(r.get(0), Some(SqlValue::Null)));
683
684 assert!(a_group.is_some());
685 assert!(null_group.is_some());
686
687 assert!(
689 matches!(a_group.unwrap().get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
690 );
691
692 assert!(
694 matches!(null_group.unwrap().get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
695 );
696 }
697}