1mod expression;
14mod functions;
15mod group_by;
16
17use crate::errors::ExecutorError;
18use crate::schema::CombinedSchema;
19use vibesql_ast::Expression;
20use vibesql_storage::Row;
21use vibesql_types::SqlValue;
22
23use super::batch::ColumnarBatch;
24use super::scan::ColumnarScan;
25
26pub use expression::extract_aggregates;
28pub use expression::evaluate_expression_to_column;
29pub use expression::evaluate_expression_with_cached_column;
30pub use group_by::columnar_group_by;
31pub use group_by::columnar_group_by_batch;
33
/// Aggregate operation requested for a column, an expression, or `COUNT(*)`.
///
/// The value is `Copy`, so it is passed by value to the aggregate kernels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// `SUM` aggregate.
    Sum,
    /// `COUNT` aggregate.
    Count,
    /// `AVG` aggregate.
    Avg,
    /// `MIN` aggregate.
    Min,
    /// `MAX` aggregate.
    Max,
}
43
44#[derive(Debug, Clone)]
46pub enum AggregateSource {
47 Column(usize),
49 Expression(Expression),
51 CountStar,
53}
54
55#[derive(Debug, Clone)]
57pub struct AggregateSpec {
58 pub op: AggregateOp,
59 pub source: AggregateSource,
60}
61
62pub fn compute_columnar_aggregate(
81 scan: &ColumnarScan,
82 column_idx: usize,
83 op: AggregateOp,
84 filter_bitmap: Option<&[bool]>,
85) -> Result<SqlValue, ExecutorError> {
86 {
88 use super::simd_aggregate::{can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64};
89
90 if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
92 return if is_integer {
94 simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
95 } else {
96 simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
97 };
98 }
99 }
101
102 functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
104}
105
106pub fn compute_multiple_aggregates(
111 rows: &[Row],
112 aggregates: &[AggregateSpec],
113 filter_bitmap: Option<&[bool]>,
114 schema: Option<&CombinedSchema>,
115) -> Result<Vec<SqlValue>, ExecutorError> {
116 let scan = ColumnarScan::new(rows);
117 let mut results = Vec::with_capacity(aggregates.len());
118
119 for spec in aggregates {
120 let result = match &spec.source {
121 AggregateSource::Column(column_idx) => {
123 compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
124 }
125 AggregateSource::Expression(expr) => {
127 let schema = schema.ok_or_else(|| {
128 ExecutorError::UnsupportedExpression(
129 "Schema required for expression aggregates".to_string()
130 )
131 })?;
132 expression::compute_expression_aggregate(rows, expr, spec.op, filter_bitmap, schema)?
133 }
134 AggregateSource::CountStar => {
136 functions::compute_count(&scan, filter_bitmap)?
137 }
138 };
139 results.push(result);
140 }
141
142 Ok(results)
143}
144
145pub fn compute_aggregates_from_batch(
168 batch: &ColumnarBatch,
169 aggregates: &[AggregateSpec],
170 schema: Option<&CombinedSchema>,
171) -> Result<Vec<SqlValue>, ExecutorError> {
172 if batch.row_count() == 0 {
174 return Ok(aggregates
175 .iter()
176 .map(|spec| match spec.op {
177 AggregateOp::Count => SqlValue::Integer(0),
178 _ => SqlValue::Null,
179 })
180 .collect());
181 }
182
183 let mut results = Vec::with_capacity(aggregates.len());
184
185 for spec in aggregates {
186 let result = match &spec.source {
187 AggregateSource::Column(column_idx) => {
189 functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
190 }
191 AggregateSource::Expression(expr) => {
194 let schema = schema.ok_or_else(|| {
195 ExecutorError::UnsupportedExpression(
196 "Schema required for expression aggregates".to_string()
197 )
198 })?;
199 match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
202 Ok(value) => value,
203 Err(ExecutorError::UnsupportedExpression(_)) => {
204 let rows = batch.to_rows()?;
206 expression::compute_expression_aggregate(&rows, expr, spec.op, None, schema)?
207 }
208 Err(other) => return Err(other),
209 }
210 }
211 AggregateSource::CountStar => {
213 SqlValue::Integer(batch.row_count() as i64)
214 }
215 };
216 results.push(result);
217 }
218
219 Ok(results)
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    use crate::schema::CombinedSchema;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    /// Absolute tolerance used when comparing floating-point aggregate results.
    const TOLERANCE: f64 = 0.001;

    /// Returns `true` when `value` is a `SqlValue::Double` within [`TOLERANCE`]
    /// of `expected`.
    fn is_double_near(value: Option<&SqlValue>, expected: f64) -> bool {
        matches!(value, Some(&SqlValue::Double(actual)) if (actual - expected).abs() < TOLERANCE)
    }

    /// Builds an unqualified column reference expression.
    fn column(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// Builds an aggregate function call expression.
    fn aggregate_call(name: &str, distinct: bool, args: Vec<Expression>) -> Expression {
        Expression::AggregateFunction { name: name.to_string(), distinct, args }
    }

    /// Wraps `columns` in a table named `test` and lifts it into a combined schema.
    fn test_schema(columns: Vec<ColumnSchema>) -> CombinedSchema {
        let table = TableSchema::new("test".to_string(), columns);
        CombinedSchema::from_table("test".to_string(), table)
    }

    /// Builds two-column rows of (Varchar key or NULL, Double amount).
    fn keyed_amount_rows(entries: &[(Option<&str>, f64)]) -> Vec<Row> {
        entries
            .iter()
            .map(|&(key, amount)| {
                let key = match key {
                    Some(text) => SqlValue::Varchar(text.to_string()),
                    None => SqlValue::Null,
                };
                Row::new(vec![key, SqlValue::Double(amount)])
            })
            .collect()
    }

    /// Three rows of (Integer, Double): (10, 1.5), (20, 2.5), (30, 3.5).
    fn make_test_rows() -> Vec<Row> {
        [(10, 1.5), (20, 2.5), (30, 3.5)]
            .iter()
            .map(|&(whole, fractional)| {
                Row::new(vec![SqlValue::Integer(whole), SqlValue::Double(fractional)])
            })
            .collect()
    }

    #[test]
    fn test_sum_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        // Integer column.
        let integer_total = functions::compute_sum(&scan, 0, None).unwrap();
        assert_eq!(integer_total, SqlValue::Integer(60));

        // Double column.
        let double_total = functions::compute_sum(&scan, 1, None).unwrap();
        assert!(is_double_near(Some(&double_total), 7.5));
    }

    #[test]
    fn test_count_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        let row_count = functions::compute_count(&scan, None).unwrap();
        assert_eq!(row_count, SqlValue::Integer(3));
    }

    #[test]
    fn test_sum_with_filter() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        // Keep the first and third rows only: 10 + 30.
        let keep = vec![true, false, true];
        let filtered_total = functions::compute_sum(&scan, 0, Some(keep.as_slice())).unwrap();
        assert_eq!(filtered_total, SqlValue::Integer(40));
    }

    #[test]
    fn test_multiple_aggregates() {
        let rows = make_test_rows();
        let specs = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
        ];

        let values = compute_multiple_aggregates(&rows, &specs, None, None).unwrap();

        assert_eq!(values.len(), 2);
        assert_eq!(values[0], SqlValue::Integer(60));
        assert!(is_double_near(values.get(1), 2.5));
    }

    #[test]
    fn test_extract_aggregates_simple() {
        let schema = test_schema(vec![
            ColumnSchema::new("col1".to_string(), DataType::Integer, false),
            ColumnSchema::new("col2".to_string(), DataType::DoublePrecision, false),
        ]);

        // SUM(col1)
        let sum_only = vec![aggregate_call("SUM", false, vec![column("col1")])];
        let extracted = extract_aggregates(&sum_only, &schema);
        assert!(extracted.is_some());
        let specs = extracted.unwrap();
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Column(0)));

        // COUNT(*)
        let count_star = vec![aggregate_call("COUNT", false, vec![Expression::Wildcard])];
        let extracted = extract_aggregates(&count_star, &schema);
        assert!(extracted.is_some());
        let specs = extracted.unwrap();
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Count));
        assert!(matches!(specs[0].source, AggregateSource::CountStar));

        // SUM(col1), AVG(col2)
        let sum_and_avg = vec![
            aggregate_call("SUM", false, vec![column("col1")]),
            aggregate_call("AVG", false, vec![column("col2")]),
        ];
        let extracted = extract_aggregates(&sum_and_avg, &schema);
        assert!(extracted.is_some());
        let specs = extracted.unwrap();
        assert_eq!(specs.len(), 2);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Column(0)));
        assert!(matches!(specs[1].op, AggregateOp::Avg));
        assert!(matches!(specs[1].source, AggregateSource::Column(1)));
    }

    #[test]
    fn test_extract_aggregates_unsupported() {
        let schema =
            test_schema(vec![ColumnSchema::new("col1".to_string(), DataType::Integer, false)]);

        // SUM(DISTINCT col1) is rejected.
        let distinct_sum = vec![aggregate_call("SUM", true, vec![column("col1")])];
        assert!(extract_aggregates(&distinct_sum, &schema).is_none());

        // A bare column reference is not an aggregate.
        let bare_column = vec![column("col1")];
        assert!(extract_aggregates(&bare_column, &schema).is_none());

        // SUM over a scalar subquery is rejected.
        let subquery = vibesql_ast::SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![],
            into_table: None,
            into_variables: None,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        };
        let sum_of_subquery = vec![aggregate_call(
            "SUM",
            false,
            vec![Expression::ScalarSubquery(Box::new(subquery))],
        )];
        assert!(extract_aggregates(&sum_of_subquery, &schema).is_none());
    }

    #[test]
    fn test_extract_aggregates_with_expression() {
        let schema = test_schema(vec![
            ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
            ColumnSchema::new("discount".to_string(), DataType::DoublePrecision, false),
        ]);

        // SUM(price * discount)
        let product = Expression::BinaryOp {
            left: Box::new(column("price")),
            op: vibesql_ast::BinaryOperator::Multiply,
            right: Box::new(column("discount")),
        };
        let exprs = vec![aggregate_call("SUM", false, vec![product])];

        let extracted = extract_aggregates(&exprs, &schema);
        assert!(extracted.is_some());
        let specs = extracted.unwrap();
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Expression(_)));
    }

    #[test]
    fn test_columnar_group_by_simple() {
        let rows = keyed_amount_rows(&[
            (Some("A"), 100.0),
            (Some("B"), 200.0),
            (Some("A"), 150.0),
            (Some("B"), 50.0),
        ]);
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let mut groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 2);

        // Output order is not relied upon: sort by the group key first.
        groups.sort_by(|a, b| a.get(0).unwrap().partial_cmp(b.get(0).unwrap()).unwrap());

        assert_eq!(groups[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
        assert!(is_double_near(groups[0].get(1), 250.0));

        assert_eq!(groups[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
        assert!(is_double_near(groups[1].get(1), 250.0));
    }

    #[test]
    fn test_columnar_group_by_multiple_group_keys() {
        let rows: Vec<Row> = [("A", 1), ("A", 2), ("B", 1), ("A", 1), ("B", 2)]
            .iter()
            .map(|&(status, category)| {
                Row::new(vec![SqlValue::Varchar(status.to_string()), SqlValue::Integer(category)])
            })
            .collect();
        let group_cols = vec![0, 1];
        let agg_cols = vec![(0, AggregateOp::Count)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 4);

        for group in &groups {
            let status = group.get(0).unwrap();
            let category = group.get(1).unwrap();
            let count = group.get(2).unwrap();

            let expected_count = match (status, category) {
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "A" => 2,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "A" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "B" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "B" => 1,
                _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
            };
            assert_eq!(count, &SqlValue::Integer(expected_count));
        }
    }

    #[test]
    fn test_columnar_group_by_multiple_aggregates() {
        let rows: Vec<Row> = [(1, 100.0, 10), (2, 200.0, 20), (1, 150.0, 15)]
            .iter()
            .map(|&(group, amount, quantity)| {
                Row::new(vec![
                    SqlValue::Integer(group),
                    SqlValue::Double(amount),
                    SqlValue::Integer(quantity),
                ])
            })
            .collect();
        let group_cols = vec![0];
        let agg_cols =
            vec![(1, AggregateOp::Sum), (2, AggregateOp::Avg), (0, AggregateOp::Count)];

        let mut groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 2);

        groups.sort_by(|a, b| a.get(0).unwrap().partial_cmp(b.get(0).unwrap()).unwrap());

        // Group 1: two rows.
        assert_eq!(groups[0].get(0), Some(&SqlValue::Integer(1)));
        assert!(is_double_near(groups[0].get(1), 250.0));
        assert!(is_double_near(groups[0].get(2), 12.5));
        assert_eq!(groups[0].get(3), Some(&SqlValue::Integer(2)));

        // Group 2: one row.
        assert_eq!(groups[1].get(0), Some(&SqlValue::Integer(2)));
        assert!(is_double_near(groups[1].get(1), 200.0));
        assert!(is_double_near(groups[1].get(2), 20.0));
        assert_eq!(groups[1].get(3), Some(&SqlValue::Integer(1)));
    }

    #[test]
    fn test_columnar_group_by_with_filter() {
        let rows = keyed_amount_rows(&[
            (Some("A"), 100.0),
            (Some("B"), 200.0),
            (Some("A"), 150.0),
            (Some("B"), 50.0),
        ]);
        // Only the second and third rows pass the filter.
        let keep = vec![false, true, true, false];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let mut groups =
            columnar_group_by(&rows, &group_cols, &agg_cols, Some(keep.as_slice())).unwrap();
        assert_eq!(groups.len(), 2);

        groups.sort_by(|a, b| a.get(0).unwrap().partial_cmp(b.get(0).unwrap()).unwrap());

        assert_eq!(groups[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
        assert!(is_double_near(groups[0].get(1), 150.0));

        assert_eq!(groups[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
        assert!(is_double_near(groups[1].get(1), 200.0));
    }

    #[test]
    fn test_columnar_group_by_empty_input() {
        let rows: Vec<Row> = Vec::new();
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        assert!(groups.is_empty());
    }

    #[test]
    fn test_columnar_group_by_null_in_group_key() {
        let rows = keyed_amount_rows(&[
            (Some("A"), 100.0),
            (None, 200.0),
            (Some("A"), 150.0),
            (None, 50.0),
        ]);
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 2);

        let a_group = groups
            .iter()
            .find(|row| matches!(row.get(0), Some(SqlValue::Varchar(s)) if s == "A"));
        let null_group = groups.iter().find(|row| matches!(row.get(0), Some(SqlValue::Null)));

        // NULL keys form a group of their own.
        assert!(a_group.is_some());
        assert!(null_group.is_some());

        assert!(is_double_near(a_group.unwrap().get(1), 250.0));
        assert!(is_double_near(null_group.unwrap().get(1), 250.0));
    }
}