1mod expression;
14mod functions;
15mod group_by;
16
17use crate::errors::ExecutorError;
18use crate::schema::CombinedSchema;
19use vibesql_ast::Expression;
20use vibesql_storage::Row;
21use vibesql_types::SqlValue;
22
23use super::batch::ColumnarBatch;
24use super::scan::ColumnarScan;
25
26pub use expression::evaluate_expression_to_column;
28pub use expression::evaluate_expression_with_cached_column;
29pub use expression::extract_aggregates;
30pub use group_by::columnar_group_by;
31pub use group_by::columnar_group_by_batch;
33
/// Kind of aggregate computation requested for a column or expression.
///
/// This is a plain `Copy` tag: it carries no data and is passed by value to
/// the functions that perform the actual computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateOp {
    /// Total of the input values (SQL `SUM`).
    Sum,
    /// Number of input values (SQL `COUNT`).
    Count,
    /// Arithmetic mean of the input values (SQL `AVG`).
    Avg,
    /// Presumably the smallest input value (SQL `MIN`) — the implementation
    /// lives outside this module.
    Min,
    /// Presumably the largest input value (SQL `MAX`) — the implementation
    /// lives outside this module.
    Max,
}
43
44#[derive(Debug, Clone)]
46pub enum AggregateSource {
47 Column(usize),
49 Expression(Expression),
51 CountStar,
53}
54
55#[derive(Debug, Clone)]
57pub struct AggregateSpec {
58 pub op: AggregateOp,
59 pub source: AggregateSource,
60}
61
62pub fn compute_columnar_aggregate(
81 scan: &ColumnarScan,
82 column_idx: usize,
83 op: AggregateOp,
84 filter_bitmap: Option<&[bool]>,
85) -> Result<SqlValue, ExecutorError> {
86 {
88 use super::simd_aggregate::{
89 can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64,
90 };
91
92 if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
94 return if is_integer {
96 simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
97 } else {
98 simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
99 };
100 }
101 }
103
104 functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
106}
107
108pub fn compute_multiple_aggregates(
113 rows: &[Row],
114 aggregates: &[AggregateSpec],
115 filter_bitmap: Option<&[bool]>,
116 schema: Option<&CombinedSchema>,
117) -> Result<Vec<SqlValue>, ExecutorError> {
118 let scan = ColumnarScan::new(rows);
119 let mut results = Vec::with_capacity(aggregates.len());
120
121 for spec in aggregates {
122 let result = match &spec.source {
123 AggregateSource::Column(column_idx) => {
125 compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
126 }
127 AggregateSource::Expression(expr) => {
129 let schema = schema.ok_or_else(|| {
130 ExecutorError::UnsupportedExpression(
131 "Schema required for expression aggregates".to_string(),
132 )
133 })?;
134 expression::compute_expression_aggregate(
135 rows,
136 expr,
137 spec.op,
138 filter_bitmap,
139 schema,
140 )?
141 }
142 AggregateSource::CountStar => functions::compute_count(&scan, filter_bitmap)?,
144 };
145 results.push(result);
146 }
147
148 Ok(results)
149}
150
151pub fn compute_aggregates_from_batch(
174 batch: &ColumnarBatch,
175 aggregates: &[AggregateSpec],
176 schema: Option<&CombinedSchema>,
177) -> Result<Vec<SqlValue>, ExecutorError> {
178 if batch.row_count() == 0 {
180 return Ok(aggregates
181 .iter()
182 .map(|spec| match spec.op {
183 AggregateOp::Count => SqlValue::Integer(0),
184 _ => SqlValue::Null,
185 })
186 .collect());
187 }
188
189 let mut results = Vec::with_capacity(aggregates.len());
190
191 for spec in aggregates {
192 let result = match &spec.source {
193 AggregateSource::Column(column_idx) => {
195 functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
196 }
197 AggregateSource::Expression(expr) => {
200 let schema = schema.ok_or_else(|| {
201 ExecutorError::UnsupportedExpression(
202 "Schema required for expression aggregates".to_string(),
203 )
204 })?;
205 match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
208 Ok(value) => value,
209 Err(ExecutorError::UnsupportedExpression(_)) => {
210 let rows = batch.to_rows()?;
212 expression::compute_expression_aggregate(
213 &rows, expr, spec.op, None, schema,
214 )?
215 }
216 Err(other) => return Err(other),
217 }
218 }
219 AggregateSource::CountStar => SqlValue::Integer(batch.row_count() as i64),
221 };
222 results.push(result);
223 }
224
225 Ok(results)
226}
227
228#[cfg(test)]
229mod tests {
230 use super::*;
231
232 fn make_test_rows() -> Vec<Row> {
233 vec![
234 Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
235 Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
236 Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
237 ]
238 }
239
240 #[test]
241 fn test_sum_aggregate() {
242 let rows = make_test_rows();
243 let scan = ColumnarScan::new(&rows);
244
245 let result = functions::compute_sum(&scan, 0, None).unwrap();
246 assert_eq!(result, SqlValue::Integer(60));
247
248 let result = functions::compute_sum(&scan, 1, None).unwrap();
249 assert!(matches!(result, SqlValue::Double(sum) if (sum - 7.5).abs() < 0.001));
250 }
251
252 #[test]
253 fn test_count_aggregate() {
254 let rows = make_test_rows();
255 let scan = ColumnarScan::new(&rows);
256
257 let result = functions::compute_count(&scan, None).unwrap();
258 assert_eq!(result, SqlValue::Integer(3));
259 }
260
261 #[test]
262 fn test_sum_with_filter() {
263 let rows = make_test_rows();
264 let scan = ColumnarScan::new(&rows);
265 let filter = vec![true, false, true]; let result = functions::compute_sum(&scan, 0, Some(&filter)).unwrap();
268 assert_eq!(result, SqlValue::Integer(40));
269 }
270
271 #[test]
272 fn test_multiple_aggregates() {
273 let rows = make_test_rows();
274 let aggregates = vec![
275 AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
276 AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
277 ];
278
279 let results = compute_multiple_aggregates(&rows, &aggregates, None, None).unwrap();
280 assert_eq!(results.len(), 2);
281 assert_eq!(results[0], SqlValue::Integer(60));
282 assert!(matches!(results[1], SqlValue::Double(avg) if (avg - 2.5).abs() < 0.001));
283 }
284
285 #[test]
286 fn test_extract_aggregates_simple() {
287 use crate::schema::CombinedSchema;
288 use vibesql_catalog::{ColumnSchema, TableSchema};
289 use vibesql_types::DataType;
290
291 let schema = TableSchema::new(
293 "test".to_string(),
294 vec![
295 ColumnSchema::new("col1".to_string(), DataType::Integer, false),
296 ColumnSchema::new("col2".to_string(), DataType::DoublePrecision, false),
297 ],
298 );
299
300 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
301
302 let exprs = vec![Expression::AggregateFunction {
304 name: "SUM".to_string(),
305 distinct: false,
306 args: vec![Expression::ColumnRef { table: None, column: "col1".to_string() }],
307 }];
308
309 let result = extract_aggregates(&exprs, &combined_schema);
310 assert!(result.is_some());
311 let aggregates = result.unwrap();
312 assert_eq!(aggregates.len(), 1);
313 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
314 assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
315
316 let exprs = vec![Expression::AggregateFunction {
318 name: "COUNT".to_string(),
319 distinct: false,
320 args: vec![Expression::Wildcard],
321 }];
322
323 let result = extract_aggregates(&exprs, &combined_schema);
324 assert!(result.is_some());
325 let aggregates = result.unwrap();
326 assert_eq!(aggregates.len(), 1);
327 assert!(matches!(aggregates[0].op, AggregateOp::Count));
328 assert!(matches!(aggregates[0].source, AggregateSource::CountStar));
329
330 let exprs = vec![
332 Expression::AggregateFunction {
333 name: "SUM".to_string(),
334 distinct: false,
335 args: vec![Expression::ColumnRef { table: None, column: "col1".to_string() }],
336 },
337 Expression::AggregateFunction {
338 name: "AVG".to_string(),
339 distinct: false,
340 args: vec![Expression::ColumnRef { table: None, column: "col2".to_string() }],
341 },
342 ];
343
344 let result = extract_aggregates(&exprs, &combined_schema);
345 assert!(result.is_some());
346 let aggregates = result.unwrap();
347 assert_eq!(aggregates.len(), 2);
348 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
349 assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
350 assert!(matches!(aggregates[1].op, AggregateOp::Avg));
351 assert!(matches!(aggregates[1].source, AggregateSource::Column(1)));
352 }
353
354 #[test]
355 fn test_extract_aggregates_unsupported() {
356 use crate::schema::CombinedSchema;
357 use vibesql_catalog::{ColumnSchema, TableSchema};
358 use vibesql_types::DataType;
359
360 let schema = TableSchema::new(
361 "test".to_string(),
362 vec![ColumnSchema::new("col1".to_string(), DataType::Integer, false)],
363 );
364
365 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
366
367 let exprs = vec![Expression::AggregateFunction {
369 name: "SUM".to_string(),
370 distinct: true,
371 args: vec![Expression::ColumnRef { table: None, column: "col1".to_string() }],
372 }];
373
374 let result = extract_aggregates(&exprs, &combined_schema);
375 assert!(result.is_none());
376
377 let exprs = vec![Expression::ColumnRef { table: None, column: "col1".to_string() }];
379
380 let result = extract_aggregates(&exprs, &combined_schema);
381 assert!(result.is_none());
382
383 let exprs = vec![Expression::AggregateFunction {
385 name: "SUM".to_string(),
386 distinct: false,
387 args: vec![Expression::ScalarSubquery(Box::new(vibesql_ast::SelectStmt {
388 with_clause: None,
389 distinct: false,
390 select_list: vec![],
391 into_table: None,
392 into_variables: None,
393 from: None,
394 where_clause: None,
395 group_by: None,
396 having: None,
397 order_by: None,
398 limit: None,
399 offset: None,
400 set_operation: None,
401 }))],
402 }];
403
404 let result = extract_aggregates(&exprs, &combined_schema);
405 assert!(result.is_none());
406 }
407
408 #[test]
409 fn test_extract_aggregates_with_expression() {
410 use crate::schema::CombinedSchema;
411 use vibesql_catalog::{ColumnSchema, TableSchema};
412 use vibesql_types::DataType;
413
414 let schema = TableSchema::new(
416 "test".to_string(),
417 vec![
418 ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
419 ColumnSchema::new("discount".to_string(), DataType::DoublePrecision, false),
420 ],
421 );
422
423 let combined_schema = CombinedSchema::from_table("test".to_string(), schema);
424
425 let exprs = vec![Expression::AggregateFunction {
427 name: "SUM".to_string(),
428 distinct: false,
429 args: vec![Expression::BinaryOp {
430 left: Box::new(Expression::ColumnRef { table: None, column: "price".to_string() }),
431 op: vibesql_ast::BinaryOperator::Multiply,
432 right: Box::new(Expression::ColumnRef {
433 table: None,
434 column: "discount".to_string(),
435 }),
436 }],
437 }];
438
439 let result = extract_aggregates(&exprs, &combined_schema);
440 assert!(result.is_some());
441 let aggregates = result.unwrap();
442 assert_eq!(aggregates.len(), 1);
443 assert!(matches!(aggregates[0].op, AggregateOp::Sum));
444 assert!(matches!(aggregates[0].source, AggregateSource::Expression(_)));
445 }
446
447 #[test]
450 fn test_columnar_group_by_simple() {
451 let rows = vec![
454 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
455 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(200.0)]),
456 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
457 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(50.0)]),
458 ];
459
460 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
464
465 assert_eq!(result.len(), 2);
467
468 let mut sorted = result;
470 sorted.sort_by(|a, b| {
471 let a_key = a.get(0).unwrap();
472 let b_key = b.get(0).unwrap();
473 a_key.partial_cmp(b_key).unwrap()
474 });
475
476 assert_eq!(sorted[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
478 assert!(
479 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
480 );
481
482 assert_eq!(sorted[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
484 assert!(
485 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
486 );
487 }
488
489 #[test]
490 fn test_columnar_group_by_multiple_group_keys() {
491 let rows = vec![
494 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(1)]),
495 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(2)]),
496 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Integer(1)]),
497 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(1)]),
498 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Integer(2)]),
499 ];
500
501 let group_cols = vec![0, 1]; let agg_cols = vec![(0, AggregateOp::Count)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
505
506 assert_eq!(result.len(), 4);
508
509 for row in &result {
511 let status = row.get(0).unwrap();
512 let category = row.get(1).unwrap();
513 let count = row.get(2).unwrap();
514
515 match (status, category) {
516 (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "A" => {
517 assert_eq!(count, &SqlValue::Integer(2)); }
519 (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "A" => {
520 assert_eq!(count, &SqlValue::Integer(1)); }
522 (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "B" => {
523 assert_eq!(count, &SqlValue::Integer(1)); }
525 (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "B" => {
526 assert_eq!(count, &SqlValue::Integer(1)); }
528 _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
529 }
530 }
531 }
532
533 #[test]
534 fn test_columnar_group_by_multiple_aggregates() {
535 let rows = vec![
538 Row::new(vec![SqlValue::Integer(1), SqlValue::Double(100.0), SqlValue::Integer(10)]),
539 Row::new(vec![SqlValue::Integer(2), SqlValue::Double(200.0), SqlValue::Integer(20)]),
540 Row::new(vec![SqlValue::Integer(1), SqlValue::Double(150.0), SqlValue::Integer(15)]),
541 ];
542
543 let group_cols = vec![0]; let agg_cols = vec![
545 (1, AggregateOp::Sum), (2, AggregateOp::Avg), (0, AggregateOp::Count), ];
549
550 let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
551
552 assert_eq!(result.len(), 2);
554
555 let mut sorted = result;
557 sorted.sort_by(|a, b| {
558 let a_key = a.get(0).unwrap();
559 let b_key = b.get(0).unwrap();
560 a_key.partial_cmp(b_key).unwrap()
561 });
562
563 assert_eq!(sorted[0].get(0), Some(&SqlValue::Integer(1)));
565 assert!(
566 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
567 );
568 assert!(
569 matches!(sorted[0].get(2), Some(&SqlValue::Double(avg)) if (avg - 12.5).abs() < 0.001)
570 );
571 assert_eq!(sorted[0].get(3), Some(&SqlValue::Integer(2)));
572
573 assert_eq!(sorted[1].get(0), Some(&SqlValue::Integer(2)));
575 assert!(
576 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 200.0).abs() < 0.001)
577 );
578 assert!(
579 matches!(sorted[1].get(2), Some(&SqlValue::Double(avg)) if (avg - 20.0).abs() < 0.001)
580 );
581 assert_eq!(sorted[1].get(3), Some(&SqlValue::Integer(1)));
582 }
583
584 #[test]
585 fn test_columnar_group_by_with_filter() {
586 let rows = vec![
589 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
590 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(200.0)]),
591 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
592 Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(50.0)]),
593 ];
594
595 let filter = vec![false, true, true, false];
597
598 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, Some(&filter)).unwrap();
602
603 assert_eq!(result.len(), 2);
605
606 let mut sorted = result;
608 sorted.sort_by(|a, b| {
609 let a_key = a.get(0).unwrap();
610 let b_key = b.get(0).unwrap();
611 a_key.partial_cmp(b_key).unwrap()
612 });
613
614 assert_eq!(sorted[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
616 assert!(
617 matches!(sorted[0].get(1), Some(&SqlValue::Double(sum)) if (sum - 150.0).abs() < 0.001)
618 );
619
620 assert_eq!(sorted[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
622 assert!(
623 matches!(sorted[1].get(1), Some(&SqlValue::Double(sum)) if (sum - 200.0).abs() < 0.001)
624 );
625 }
626
627 #[test]
628 fn test_columnar_group_by_empty_input() {
629 let rows: Vec<Row> = vec![];
630 let group_cols = vec![0];
631 let agg_cols = vec![(1, AggregateOp::Sum)];
632
633 let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
634
635 assert_eq!(result.len(), 0);
637 }
638
639 #[test]
640 fn test_columnar_group_by_null_in_group_key() {
641 let rows = vec![
643 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
644 Row::new(vec![SqlValue::Null, SqlValue::Double(200.0)]),
645 Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
646 Row::new(vec![SqlValue::Null, SqlValue::Double(50.0)]),
647 ];
648
649 let group_cols = vec![0]; let agg_cols = vec![(1, AggregateOp::Sum)]; let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
653
654 assert_eq!(result.len(), 2);
656
657 let a_group =
659 result.iter().find(|r| matches!(r.get(0), Some(SqlValue::Varchar(s)) if s == "A"));
660 let null_group = result.iter().find(|r| matches!(r.get(0), Some(SqlValue::Null)));
661
662 assert!(a_group.is_some());
663 assert!(null_group.is_some());
664
665 assert!(
667 matches!(a_group.unwrap().get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
668 );
669
670 assert!(
672 matches!(null_group.unwrap().get(1), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
673 );
674 }
675}