vibesql_executor/select/columnar/aggregate/
mod.rs

1//! Columnar aggregation - high-performance aggregate computation
2//!
3//! This module provides efficient aggregate operations on columnar data,
4//! organized into several sub-modules:
5//!
6//! - [`functions`] - Core aggregate function implementations (SUM, COUNT, AVG, MIN, MAX)
7//! - [`expression`] - Expression aggregates (e.g., SUM(a * b))
8//! - [`group_by`] - Hash-based GROUP BY aggregation
9//!
10//! The public API maintains backward compatibility while the implementation
11//! is split across focused modules for better maintainability.
12
13mod expression;
14mod functions;
15mod group_by;
16
17use crate::errors::ExecutorError;
18use crate::schema::CombinedSchema;
19use vibesql_ast::Expression;
20use vibesql_storage::Row;
21use vibesql_types::SqlValue;
22
23use super::batch::ColumnarBatch;
24use super::scan::ColumnarScan;
25
26// Re-export public types and functions to maintain API compatibility
27pub use expression::extract_aggregates;
28pub use expression::evaluate_expression_to_column;
29pub use expression::evaluate_expression_with_cached_column;
30pub use group_by::columnar_group_by;
31// SIMD-accelerated GROUP BY for ColumnarBatch - used in native columnar execution path
32pub use group_by::columnar_group_by_batch;
33
/// Aggregate operation type
///
/// Identifies which aggregate function is applied to a column, an expression,
/// or (for `COUNT(*)`) the whole input. The type is a plain field-less enum, so
/// it is `Copy` and can be compared and hashed cheaply (e.g. used as a map key
/// when grouping aggregate specifications by operation).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateOp {
    /// `SUM` aggregate
    Sum,
    /// `COUNT` aggregate
    Count,
    /// `AVG` aggregate
    Avg,
    /// `MIN` aggregate
    Min,
    /// `MAX` aggregate
    Max,
}
43
44/// Source of data for an aggregate - either a simple column or an expression
45#[derive(Debug, Clone)]
46pub enum AggregateSource {
47    /// Simple column reference (fast path) - just a column index
48    Column(usize),
49    /// Complex expression that needs evaluation (e.g., a * b)
50    Expression(Expression),
51    /// COUNT(*) - count all rows (not tied to any specific column)
52    CountStar,
53}
54
55/// A complete aggregate specification
56#[derive(Debug, Clone)]
57pub struct AggregateSpec {
58    pub op: AggregateOp,
59    pub source: AggregateSource,
60}
61
62/// Compute an aggregate over a column with optional filtering
63///
64/// This is the core columnar aggregation function that processes
65/// columns directly without materializing Row objects.
66///
67/// Automatically detects when SIMD optimization is available (Int64/Float64 columns)
68/// and falls back to scalar implementation for other types.
69///
70/// # Arguments
71///
72/// * `scan` - Columnar scan over the data
73/// * `column_idx` - Index of the column to aggregate
74/// * `op` - Aggregate operation (SUM, COUNT, etc.)
75/// * `filter_bitmap` - Optional bitmap of which rows to include
76///
77/// # Returns
78///
79/// The aggregated SqlValue
80pub fn compute_columnar_aggregate(
81    scan: &ColumnarScan,
82    column_idx: usize,
83    op: AggregateOp,
84    filter_bitmap: Option<&[bool]>,
85) -> Result<SqlValue, ExecutorError> {
86    // Try SIMD path for numeric columns (5-10x speedup via LLVM auto-vectorization)
87    {
88        use super::simd_aggregate::{can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64};
89
90        // Detect if column is SIMD-compatible
91        if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
92            // Use SIMD implementation for Int64/Float64 columns
93            return if is_integer {
94                simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
95            } else {
96                simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
97            };
98        }
99        // Fall through to scalar path for non-SIMD types
100    }
101
102    // Scalar fallback path (always available, used for String, Date, etc.)
103    functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
104}
105
106/// Compute multiple aggregates in a single pass over the data
107///
108/// This is more efficient than computing each aggregate separately
109/// as it only scans the data once.
110pub fn compute_multiple_aggregates(
111    rows: &[Row],
112    aggregates: &[AggregateSpec],
113    filter_bitmap: Option<&[bool]>,
114    schema: Option<&CombinedSchema>,
115) -> Result<Vec<SqlValue>, ExecutorError> {
116    let scan = ColumnarScan::new(rows);
117    let mut results = Vec::with_capacity(aggregates.len());
118
119    for spec in aggregates {
120        let result = match &spec.source {
121            // Fast path: direct column aggregation
122            AggregateSource::Column(column_idx) => {
123                compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
124            }
125            // Expression path: evaluate expression for each row, then aggregate
126            AggregateSource::Expression(expr) => {
127                let schema = schema.ok_or_else(|| {
128                    ExecutorError::UnsupportedExpression(
129                        "Schema required for expression aggregates".to_string()
130                    )
131                })?;
132                expression::compute_expression_aggregate(rows, expr, spec.op, filter_bitmap, schema)?
133            }
134            // COUNT(*) path: count all rows
135            AggregateSource::CountStar => {
136                functions::compute_count(&scan, filter_bitmap)?
137            }
138        };
139        results.push(result);
140    }
141
142    Ok(results)
143}
144
145/// Compute aggregates directly from a ColumnarBatch (no row conversion)
146///
147/// This is a high-performance path that eliminates the overhead of converting
148/// ColumnarBatch back to rows before aggregation. It works directly on the
149/// typed column arrays for maximum efficiency.
150///
151/// # Arguments
152///
153/// * `batch` - The ColumnarBatch to aggregate over (typically filtered)
154/// * `aggregates` - List of aggregate specifications to compute
155/// * `schema` - Optional schema for expression aggregates
156///
157/// # Performance
158///
159/// This function provides 20-30% speedup over the row-based path by:
160/// - Avoiding `batch.to_rows()` conversion overhead (~10-15ms for large batches)
161/// - Working directly on typed arrays (Int64, Float64) with SIMD operations
162/// - Reducing memory allocations
163///
164/// # Returns
165///
166/// Vector of aggregate results in the same order as the input aggregates
167pub fn compute_aggregates_from_batch(
168    batch: &ColumnarBatch,
169    aggregates: &[AggregateSpec],
170    schema: Option<&CombinedSchema>,
171) -> Result<Vec<SqlValue>, ExecutorError> {
172    // Handle empty batch
173    if batch.row_count() == 0 {
174        return Ok(aggregates
175            .iter()
176            .map(|spec| match spec.op {
177                AggregateOp::Count => SqlValue::Integer(0),
178                _ => SqlValue::Null,
179            })
180            .collect());
181    }
182
183    let mut results = Vec::with_capacity(aggregates.len());
184
185    for spec in aggregates {
186        let result = match &spec.source {
187            // Fast path: direct batch aggregation (no row conversion)
188            AggregateSource::Column(column_idx) => {
189                functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
190            }
191            // Expression path: SIMD-accelerated evaluation directly on batch columns
192            // This eliminates the ~10-15ms overhead of batch.to_rows() for large batches
193            AggregateSource::Expression(expr) => {
194                let schema = schema.ok_or_else(|| {
195                    ExecutorError::UnsupportedExpression(
196                        "Schema required for expression aggregates".to_string()
197                    )
198                })?;
199                // Try batch-native expression evaluation first
200                // Falls back to row-based evaluation if column types aren't supported
201                match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
202                    Ok(value) => value,
203                    Err(ExecutorError::UnsupportedExpression(_)) => {
204                        // Fall back to row-based for unsupported column types (Mixed, Date, etc.)
205                        let rows = batch.to_rows()?;
206                        expression::compute_expression_aggregate(&rows, expr, spec.op, None, schema)?
207                    }
208                    Err(other) => return Err(other),
209                }
210            }
211            // COUNT(*) path: just count rows in batch
212            AggregateSource::CountStar => {
213                SqlValue::Integer(batch.row_count() as i64)
214            }
215        };
216        results.push(result);
217    }
218
219    Ok(results)
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    // ---------------------------------------------------------------------
    // Shared fixtures and helpers
    // ---------------------------------------------------------------------

    /// Absolute tolerance used when comparing floating-point aggregates.
    const TOLERANCE: f64 = 0.001;

    /// Three rows of `(Integer, Double)`: (10, 1.5), (20, 2.5), (30, 3.5).
    fn make_test_rows() -> Vec<Row> {
        [(10, 1.5), (20, 2.5), (30, 3.5)]
            .iter()
            .map(|&(int_value, float_value)| {
                Row::new(vec![SqlValue::Integer(int_value), SqlValue::Double(float_value)])
            })
            .collect()
    }

    /// Four rows of `(status, amount)`: A/100, B/200, A/150, B/50.
    fn status_amount_rows() -> Vec<Row> {
        vec![
            Row::new(vec![varchar("A"), SqlValue::Double(100.0)]),
            Row::new(vec![varchar("B"), SqlValue::Double(200.0)]),
            Row::new(vec![varchar("A"), SqlValue::Double(150.0)]),
            Row::new(vec![varchar("B"), SqlValue::Double(50.0)]),
        ]
    }

    /// Build a `Varchar` value from a string slice.
    fn varchar(text: &str) -> SqlValue {
        SqlValue::Varchar(text.to_string())
    }

    /// True when `value` is a `Double` within `TOLERANCE` of `expected`.
    fn is_double_near(value: Option<&SqlValue>, expected: f64) -> bool {
        matches!(value, Some(&SqlValue::Double(actual)) if (actual - expected).abs() < TOLERANCE)
    }

    /// Build a combined schema for a single table named "test".
    fn schema_with(columns: Vec<ColumnSchema>) -> CombinedSchema {
        let table = TableSchema::new("test".to_string(), columns);
        CombinedSchema::from_table("test".to_string(), table)
    }

    /// Unqualified column reference expression.
    fn column(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// Aggregate function call expression.
    fn aggregate_call(name: &str, distinct: bool, args: Vec<Expression>) -> Expression {
        Expression::AggregateFunction { name: name.to_string(), distinct, args }
    }

    // ---------------------------------------------------------------------
    // Plain aggregates
    // ---------------------------------------------------------------------

    #[test]
    fn test_sum_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        // Integer column: 10 + 20 + 30
        let int_sum = functions::compute_sum(&scan, 0, None).unwrap();
        assert_eq!(int_sum, SqlValue::Integer(60));

        // Double column: 1.5 + 2.5 + 3.5
        let float_sum = functions::compute_sum(&scan, 1, None).unwrap();
        assert!(is_double_near(Some(&float_sum), 7.5));
    }

    #[test]
    fn test_count_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        let row_count = functions::compute_count(&scan, None).unwrap();
        assert_eq!(row_count, SqlValue::Integer(3));
    }

    #[test]
    fn test_sum_with_filter() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);
        // Keep rows 0 and 2 only
        let filter = vec![true, false, true];

        let filtered_sum = functions::compute_sum(&scan, 0, Some(&filter)).unwrap();
        assert_eq!(filtered_sum, SqlValue::Integer(40));
    }

    #[test]
    fn test_multiple_aggregates() {
        let rows = make_test_rows();
        let aggregates = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
        ];

        let results = compute_multiple_aggregates(&rows, &aggregates, None, None).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], SqlValue::Integer(60));
        assert!(is_double_near(Some(&results[1]), 2.5));
    }

    // ---------------------------------------------------------------------
    // extract_aggregates
    // ---------------------------------------------------------------------

    #[test]
    fn test_extract_aggregates_simple() {
        // Two-column schema: col1 INTEGER, col2 DOUBLE PRECISION
        let combined_schema = schema_with(vec![
            ColumnSchema::new("col1".to_string(), DataType::Integer, false),
            ColumnSchema::new("col2".to_string(), DataType::DoublePrecision, false),
        ]);

        // SUM(col1)
        let exprs = vec![aggregate_call("SUM", false, vec![column("col1")])];

        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_some());
        let aggregates = result.unwrap();
        assert_eq!(aggregates.len(), 1);
        assert!(matches!(aggregates[0].op, AggregateOp::Sum));
        assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));

        // COUNT(*)
        let exprs = vec![aggregate_call("COUNT", false, vec![Expression::Wildcard])];

        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_some());
        let aggregates = result.unwrap();
        assert_eq!(aggregates.len(), 1);
        assert!(matches!(aggregates[0].op, AggregateOp::Count));
        assert!(matches!(aggregates[0].source, AggregateSource::CountStar));

        // SUM(col1), AVG(col2)
        let exprs = vec![
            aggregate_call("SUM", false, vec![column("col1")]),
            aggregate_call("AVG", false, vec![column("col2")]),
        ];

        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_some());
        let aggregates = result.unwrap();
        assert_eq!(aggregates.len(), 2);
        assert!(matches!(aggregates[0].op, AggregateOp::Sum));
        assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
        assert!(matches!(aggregates[1].op, AggregateOp::Avg));
        assert!(matches!(aggregates[1].source, AggregateSource::Column(1)));
    }

    #[test]
    fn test_extract_aggregates_unsupported() {
        let combined_schema =
            schema_with(vec![ColumnSchema::new("col1".to_string(), DataType::Integer, false)]);

        // DISTINCT aggregate is rejected
        let exprs = vec![aggregate_call("SUM", true, vec![column("col1")])];
        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_none());

        // A bare column reference is not an aggregate
        let exprs = vec![column("col1")];
        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_none());

        // A scalar subquery as the aggregate argument is rejected
        let empty_select = vibesql_ast::SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![],
            into_table: None,
            into_variables: None,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        };
        let exprs = vec![aggregate_call(
            "SUM",
            false,
            vec![Expression::ScalarSubquery(Box::new(empty_select))],
        )];
        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_aggregates_with_expression() {
        // Two-column schema: price DOUBLE PRECISION, discount DOUBLE PRECISION
        let combined_schema = schema_with(vec![
            ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
            ColumnSchema::new("discount".to_string(), DataType::DoublePrecision, false),
        ]);

        // SUM(price * discount)
        let product = Expression::BinaryOp {
            left: Box::new(column("price")),
            op: vibesql_ast::BinaryOperator::Multiply,
            right: Box::new(column("discount")),
        };
        let exprs = vec![aggregate_call("SUM", false, vec![product])];

        let result = extract_aggregates(&exprs, &combined_schema);
        assert!(result.is_some());
        let aggregates = result.unwrap();
        assert_eq!(aggregates.len(), 1);
        assert!(matches!(aggregates[0].op, AggregateOp::Sum));
        assert!(matches!(aggregates[0].source, AggregateSource::Expression(_)));
    }

    // ---------------------------------------------------------------------
    // GROUP BY
    // ---------------------------------------------------------------------

    #[test]
    fn test_columnar_group_by_simple() {
        // SELECT status, SUM(amount) FROM test GROUP BY status
        let rows = status_amount_rows();
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let mut sorted = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Two groups: A and B
        assert_eq!(sorted.len(), 2);

        // Group order is not guaranteed, so order by the group key first
        sorted.sort_by(|lhs, rhs| {
            lhs.get(0).unwrap().partial_cmp(rhs.get(0).unwrap()).unwrap()
        });

        // A: 100 + 150
        assert_eq!(sorted[0].get(0), Some(&varchar("A")));
        assert!(is_double_near(sorted[0].get(1), 250.0));

        // B: 200 + 50
        assert_eq!(sorted[1].get(0), Some(&varchar("B")));
        assert!(is_double_near(sorted[1].get(1), 250.0));
    }

    #[test]
    fn test_columnar_group_by_multiple_group_keys() {
        // SELECT status, category, COUNT(*) FROM test GROUP BY status, category
        let rows = vec![
            Row::new(vec![varchar("A"), SqlValue::Integer(1)]),
            Row::new(vec![varchar("A"), SqlValue::Integer(2)]),
            Row::new(vec![varchar("B"), SqlValue::Integer(1)]),
            Row::new(vec![varchar("A"), SqlValue::Integer(1)]),
            Row::new(vec![varchar("B"), SqlValue::Integer(2)]),
        ];
        let group_cols = vec![0, 1];
        let agg_cols = vec![(0, AggregateOp::Count)];

        let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Four groups: (A,1), (A,2), (B,1), (B,2)
        assert_eq!(result.len(), 4);

        for row in &result {
            let status = row.get(0).unwrap();
            let category = row.get(1).unwrap();
            let count = row.get(2).unwrap();

            // Number of input rows that carry this (status, category) pair
            let expected_count = match (status, category) {
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "A" => 2,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "A" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "B" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "B" => 1,
                _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
            };
            assert_eq!(count, &SqlValue::Integer(expected_count));
        }
    }

    #[test]
    fn test_columnar_group_by_multiple_aggregates() {
        // SELECT category, SUM(price), AVG(quantity), COUNT(*) FROM test GROUP BY category
        let rows = vec![
            Row::new(vec![SqlValue::Integer(1), SqlValue::Double(100.0), SqlValue::Integer(10)]),
            Row::new(vec![SqlValue::Integer(2), SqlValue::Double(200.0), SqlValue::Integer(20)]),
            Row::new(vec![SqlValue::Integer(1), SqlValue::Double(150.0), SqlValue::Integer(15)]),
        ];
        let group_cols = vec![0];
        let agg_cols = vec![
            (1, AggregateOp::Sum),   // SUM(price)
            (2, AggregateOp::Avg),   // AVG(quantity)
            (0, AggregateOp::Count), // COUNT(*)
        ];

        let mut sorted = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        assert_eq!(sorted.len(), 2);

        // Order by category for deterministic checks
        sorted.sort_by(|lhs, rhs| {
            lhs.get(0).unwrap().partial_cmp(rhs.get(0).unwrap()).unwrap()
        });

        // Category 1: SUM=250, AVG=12.5, COUNT=2
        assert_eq!(sorted[0].get(0), Some(&SqlValue::Integer(1)));
        assert!(is_double_near(sorted[0].get(1), 250.0));
        assert!(is_double_near(sorted[0].get(2), 12.5));
        assert_eq!(sorted[0].get(3), Some(&SqlValue::Integer(2)));

        // Category 2: SUM=200, AVG=20.0, COUNT=1
        assert_eq!(sorted[1].get(0), Some(&SqlValue::Integer(2)));
        assert!(is_double_near(sorted[1].get(1), 200.0));
        assert!(is_double_near(sorted[1].get(2), 20.0));
        assert_eq!(sorted[1].get(3), Some(&SqlValue::Integer(1)));
    }

    #[test]
    fn test_columnar_group_by_with_filter() {
        // SELECT status, SUM(amount) FROM test WHERE amount > 100 GROUP BY status
        let rows = status_amount_rows();

        // amount > 100 holds for rows 1 and 2 only
        let filter = vec![false, true, true, false];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let mut sorted =
            columnar_group_by(&rows, &group_cols, &agg_cols, Some(&filter)).unwrap();

        // Both statuses still have one surviving row each
        assert_eq!(sorted.len(), 2);

        sorted.sort_by(|lhs, rhs| {
            lhs.get(0).unwrap().partial_cmp(rhs.get(0).unwrap()).unwrap()
        });

        // A: only row 2 (150.0) survives
        assert_eq!(sorted[0].get(0), Some(&varchar("A")));
        assert!(is_double_near(sorted[0].get(1), 150.0));

        // B: only row 1 (200.0) survives
        assert_eq!(sorted[1].get(0), Some(&varchar("B")));
        assert!(is_double_near(sorted[1].get(1), 200.0));
    }

    #[test]
    fn test_columnar_group_by_empty_input() {
        let rows: Vec<Row> = Vec::new();
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // No input rows, no groups
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_columnar_group_by_null_in_group_key() {
        // NULL group keys must form their own group
        let rows = vec![
            Row::new(vec![varchar("A"), SqlValue::Double(100.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(200.0)]),
            Row::new(vec![varchar("A"), SqlValue::Double(150.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(50.0)]),
        ];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let result = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Two groups: "A" and NULL
        assert_eq!(result.len(), 2);

        let a_group = result
            .iter()
            .find(|group| matches!(group.get(0), Some(SqlValue::Varchar(s)) if s == "A"));
        let null_group =
            result.iter().find(|group| matches!(group.get(0), Some(SqlValue::Null)));

        assert!(a_group.is_some());
        assert!(null_group.is_some());

        // "A": 100 + 150
        assert!(is_double_near(a_group.unwrap().get(1), 250.0));

        // NULL: 200 + 50
        assert!(is_double_near(null_group.unwrap().get(1), 250.0));
    }
}
666}